From 679d543f8e84a8290064a72319477c45084cb303 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Wed, 19 Nov 2025 16:26:23 -0800 Subject: [PATCH 01/63] docs: Update ADK triaging agent to only triage planned issues It also enables the ADK triaging agent to run periodically on planned but not triaged issues. Co-authored-by: Xuan Yang PiperOrigin-RevId: 834489103 --- .github/workflows/triage.yml | 7 +++- .../samples/adk_triaging_agent/agent.py | 37 ++++++++++++++++ .../samples/adk_triaging_agent/main.py | 42 +++++++++++++------ 3 files changed, 73 insertions(+), 13 deletions(-) diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml index 57e729e9b5..0d310cee89 100644 --- a/.github/workflows/triage.yml +++ b/.github/workflows/triage.yml @@ -2,11 +2,16 @@ name: ADK Issue Triaging Agent on: issues: - types: [opened, reopened] + types: [labeled] + schedule: + # Run every 6 hours to triage planned but not triaged issues + - cron: '0 */6 * * *' jobs: agent-triage-issues: runs-on: ubuntu-latest + # Only run if labeled with "planned" or if it's a scheduled run + if: github.event_name == 'schedule' || github.event.label.name == 'planned' permissions: issues: write contents: read diff --git a/contributing/samples/adk_triaging_agent/agent.py b/contributing/samples/adk_triaging_agent/agent.py index 4ffcc35235..9504e72dff 100644 --- a/contributing/samples/adk_triaging_agent/agent.py +++ b/contributing/samples/adk_triaging_agent/agent.py @@ -110,6 +110,42 @@ def list_unlabeled_issues(issue_count: int) -> dict[str, Any]: return {"status": "success", "issues": unlabeled_issues} +def list_planned_untriaged_issues(issue_count: int) -> dict[str, Any]: + """List planned issues without component labels (e.g., core, tools, etc.). + + Args: + issue_count: number of issues to return + + Returns: + The status of this request, with a list of issues when successful. 
+ """ + url = f"{GITHUB_BASE_URL}/search/issues" + query = f"repo:{OWNER}/{REPO} is:open is:issue label:planned" + params = { + "q": query, + "sort": "created", + "order": "desc", + "per_page": issue_count, + "page": 1, + } + + try: + response = get_request(url, params) + except requests.exceptions.RequestException as e: + return error_response(f"Error: {e}") + issues = response.get("items", []) + + # Filter out issues that already have component labels + component_labels = set(LABEL_TO_OWNER.keys()) + untriaged_issues = [] + for issue in issues: + issue_labels = {label["name"] for label in issue.get("labels", [])} + # If the issue only has "planned" but no component labels, it's untriaged + if not (issue_labels & component_labels): + untriaged_issues.append(issue) + return {"status": "success", "issues": untriaged_issues} + + def add_label_and_owner_to_issue( issue_number: int, label: str ) -> dict[str, Any]: @@ -241,6 +277,7 @@ def change_issue_type(issue_number: int, issue_type: str) -> dict[str, Any]: """, tools=[ list_unlabeled_issues, + list_planned_untriaged_issues, add_label_and_owner_to_issue, change_issue_type, ], diff --git a/contributing/samples/adk_triaging_agent/main.py b/contributing/samples/adk_triaging_agent/main.py index 317f5893e2..f608a696c0 100644 --- a/contributing/samples/adk_triaging_agent/main.py +++ b/contributing/samples/adk_triaging_agent/main.py @@ -16,6 +16,7 @@ import time from adk_triaging_agent import agent +from adk_triaging_agent.agent import LABEL_TO_OWNER from adk_triaging_agent.settings import EVENT_NAME from adk_triaging_agent.settings import GITHUB_BASE_URL from adk_triaging_agent.settings import ISSUE_BODY @@ -37,21 +38,32 @@ async def fetch_specific_issue_details(issue_number: int): - """Fetches details for a single issue if it's unlabelled.""" + """Fetches details for a single issue if it needs triaging.""" url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{issue_number}" print(f"Fetching details for specific issue: {url}") try: issue_data = get_request(url) - if not issue_data.get("labels", None): - print(f"Issue #{issue_number} is unlabelled. Proceeding.") + labels = issue_data.get("labels", []) + label_names = {label["name"] for label in labels} + + # Check if issue has "planned" label but no component labels + component_labels = set(LABEL_TO_OWNER.keys()) + has_planned = "planned" in label_names + has_component = bool(label_names & component_labels) + + if has_planned and not has_component: + print(f"Issue #{issue_number} is planned but not triaged. Proceeding.") return { "number": issue_data["number"], "title": issue_data["title"], "body": issue_data.get("body", ""), } else: - print(f"Issue #{issue_number} is already labelled. Skipping.") + print( + f"Issue #{issue_number} is already triaged or doesn't have" + " 'planned' label. Skipping." + ) return None except requests.exceptions.RequestException as e: print(f"Error fetching issue #{issue_number}: {e}") @@ -108,26 +120,32 @@ async def main(): specific_issue = await fetch_specific_issue_details(issue_number) if specific_issue is None: print( - f"No unlabelled issue details found for #{issue_number} or an error" - " occurred. Skipping agent interaction." + f"No issue details found for #{issue_number} that needs triaging," + " or an error occurred. Skipping agent interaction." ) return issue_title = ISSUE_TITLE or specific_issue["title"] issue_body = ISSUE_BODY or specific_issue["body"] prompt = ( - f"A new GitHub issue #{issue_number} has been opened or" - f' reopened. 
Title: "{issue_title}"\nBody:' + f"A GitHub issue #{issue_number} has been labeled as 'planned'." + f' Title: "{issue_title}"\nBody:' f' "{issue_body}"\n\nBased on the rules, recommend an' - " appropriate label and its justification." - " Then, use the 'add_label_to_issue' tool to apply the label " - "directly to this issue. Only label it, do not" + " appropriate component label and its justification." + " Then, use the 'add_label_and_owner_to_issue' tool to apply the" + " label directly to this issue. Only label it, do not" " process any other issues." ) else: print(f"EVENT: Processing batch of issues (event: {EVENT_NAME}).") issue_count = parse_number_string(ISSUE_COUNT_TO_PROCESS, default_value=3) - prompt = f"Please triage the most recent {issue_count} issues." + prompt = ( + "Please use the 'list_planned_untriaged_issues' tool to find the" + f" most recent {issue_count} planned issues that haven't been" + " triaged yet (i.e., issues with 'planned' label but no component" + " labels like 'core', 'tools', etc.). Then triage each of them by" + " applying appropriate component labels." + ) response = await call_agent_async(runner, USER_ID, session.id, prompt) print(f"<<<< Agent Final Output: {response}\n") From 8fc6128b62ba576480d196d4a2597564fd0a7006 Mon Sep 17 00:00:00 2001 From: Divyansh Shukla Date: Wed, 19 Nov 2025 16:35:48 -0800 Subject: [PATCH 02/63] fix: Fix out of bounds error in _run_async_impl PiperOrigin-RevId: 834492696 --- src/google/adk/agents/llm_agent.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 2f8a969fad..71a074881c 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -469,10 +469,7 @@ async def _run_async_impl( if ctx.is_resumable: events = ctx._get_events(current_invocation=True, current_branch=True) - if events and ( - ctx.should_pause_invocation(events[-1]) - or ctx.should_pause_invocation(events[-2]) - ): + if any(ctx.should_pause_invocation(e) for e in events[-2:]): return # Only yield an end state if the last event is no longer a long running # tool call. 
From a6e4d6c0d96a0d4684760d004a961f82579b9060 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 19 Nov 2025 17:12:20 -0800 Subject: [PATCH 03/63] chore: Bumps version to v1.19.0 and updates CHANGELOG.md Co-authored-by: Shangjie Chen PiperOrigin-RevId: 834503873 --- CHANGELOG.md | 100 ++++++++++++++++++++++++++++++++++++++ src/google/adk/version.py | 2 +- 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ced7b7026b..72e0c7b19f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,105 @@ # Changelog +## [1.19.0](https://github.com/google/adk-python/compare/v1.18.0...v1.19.0) (2025-11-19) + +### Features + +* **[Core]** + * Add `id` and `custom_metadata` fields to `MemoryEntry` ([4dd28a3](https://github.com/google/adk-python/commit/4dd28a3970d0f76c571caf80b3e1bea1b79e9dde)) + * Add progressive SSE streaming feature ([a5ac1d5](https://github.com/google/adk-python/commit/a5ac1d5e14f5ce7cd875d81a494a773710669dc1)) + * Add a2a_request_meta_provider to RemoteAgent init ([d12468e](https://github.com/google/adk-python/commit/d12468ee5a2b906b6699ccdb94c6a5a4c2822465)) + * Add feature decorator for the feature registry system ([871da73](https://github.com/google/adk-python/commit/871da731f1c09c6a62d51b137d9d2e7c9fb3897a)) + * Breaking: Raise minimum Python version to 3_10 ([8402832](https://github.com/google/adk-python/commit/840283228ee77fb3dbd737cfe7eb8736d9be5ec8)) + * Refactor and rename BigQuery agent analytics plugin ([6b14f88](https://github.com/google/adk-python/commit/6b14f887262722ccb85dcd6cef9c0e9b103cfa6e)) + * Pass custom_metadata through forwarding artifact service ([c642f13](https://github.com/google/adk-python/commit/c642f13f216fb64bc93ac46c1c57702c8a2add8c)) + * Update save_files_as_artifacts_plugin to never keep inline data ([857de04](https://github.com/google/adk-python/commit/857de04debdeba421075c2283c9bd8518d586624)) + +* **[Evals]** + * Add support for InOrder and AnyOrder match in ToolTrajectoryAvgScore Metric ([e2d3b2d](https://github.com/google/adk-python/commit/e2d3b2d862f7fc93807d16089307d4df25367a24)) + +* **[Integrations]** + * Enhance BQ Plugin Schema, Error Handling, and Logging ([5ac5129](https://github.com/google/adk-python/commit/5ac5129fb01913516d6f5348a825ca83d024d33a)) + * Schema Enhancements with Descriptions, Partitioning, and Truncation Indicator ([7c993b0](https://github.com/google/adk-python/commit/7c993b01d1b9d582b4e2348f73c0591d47bf2f3a)) + +* **[Services]** + * Add file-backed artifact service ([99ca6aa](https://github.com/google/adk-python/commit/99ca6aa6e6b4027f37d091d9c93da6486def20d7)) + * Add service factory for configurable session and artifact backends ([a12ae81](https://github.com/google/adk-python/commit/a12ae812d367d2d00ab246f85a73ed679dd3828a)) + * Add SqliteSessionService and a migration script to migrate existing DB using DatabaseSessionService to SqliteSessionService ([e218254](https://github.com/google/adk-python/commit/e2182544952c0174d1a8307fbba319456dca748b)) + * Add transcription fields to session events ([3ad30a5](https://github.com/google/adk-python/commit/3ad30a58f95b8729f369d00db799546069d7b23a)) + * Full async implementation of DatabaseSessionService ([7495941](https://github.com/google/adk-python/commit/74959414d8ded733d584875a49fb4638a12d3ce5)) + +* **[Models]** + * Add experimental feature to use `parameters_json_schema` and `response_json_schema` for McpTool ([1dd97f5](https://github.com/google/adk-python/commit/1dd97f5b45226c25e4c51455c78ebf3ff56ab46a)) + * Add support for 
parsing inline JSON tool calls in LiteLLM responses ([22eb7e5](https://github.com/google/adk-python/commit/22eb7e5b06c9e048da5bb34fe7ae9135d00acb4e)) + * Expose artifact URLs to the model when available ([e3caf79](https://github.com/google/adk-python/commit/e3caf791395ce3cc0b10410a852be6e7b0d8d3b1)) + +* **[Tools]** + * Add BigQuery related label handling ([ffbab4c](https://github.com/google/adk-python/commit/ffbab4cf4ed6ceb313241c345751214d3c0e11ce)) + * Allow setting max_billed_bytes in BigQuery tools config ([ffbb0b3](https://github.com/google/adk-python/commit/ffbb0b37e128de50ebf57d76cba8b743a8b970d5)) + * Propagate `application_name` set for the BigQuery Tools as BigQuery job labels ([f13a11e](https://github.com/google/adk-python/commit/f13a11e1dc27c5aa46345154fbe0eecfe1690cbb)) + * Set per-tool user agent in BQ calls and tool label in BQ jobs ([c0be1df](https://github.com/google/adk-python/commit/c0be1df0521cfd4b84585f404d4385b80d08ba59)) + +* **[Observability]** + * Migrate BigQuery logging to Storage Write API ([a2ce34a](https://github.com/google/adk-python/commit/a2ce34a0b9a8403f830ff637d0e2094e82dee8e7)) + +### Bug Fixes + +* Add `jsonschema` dependency for Agent Builder config validation ([0fa7e46](https://github.com/google/adk-python/commit/0fa7e4619d589dc834f7508a18bc2a3b93ec7fd9)) +* Add None check for `event` in `remote_a2a_agent.py` ([744f94f](https://github.com/google/adk-python/commit/744f94f0c8736087724205bbbad501640b365270)) +* Add vertexai initialization for code being deployed to AgentEngine ([b8e4aed](https://github.com/google/adk-python/commit/b8e4aedfbf0eb55b34599ee24e163b41072a699c)) +* Change LiteLLM content and tool parameter handling ([a19be12](https://github.com/google/adk-python/commit/a19be12c1f04bb62a8387da686499857c24b45c0)) +* Change name for builder agent ([131d39c](https://github.com/google/adk-python/commit/131d39c3db1ae25e3911fa7f72afbe05e24a1c37)) +* Ensure event compaction completes by awaiting task ([b5f5df9](https://github.com/google/adk-python/commit/b5f5df9fa8f616b855c186fcef45bade00653c77)) +* Fix deploy to cloud run on Windows ([29fea7e](https://github.com/google/adk-python/commit/29fea7ec1fb27989f07c90494b2d6acbe76c03d8)) +* Fix error handling when MCP server is unreachable ([ee8106b](https://github.com/google/adk-python/commit/ee8106be77f253e3687e72ae0e236687d254965c)) +* Fix error when query job destination is None ([0ccc43c](https://github.com/google/adk-python/commit/0ccc43cf49dc0882dc896455d6603a602d8a28e7)) +* Fix Improve logic for checking if a MCP session is disconnected ([a754c96](https://github.com/google/adk-python/commit/a754c96d3c4fd00f9c2cd924fc428b68cc5115fb)) +* Fix McpToolset crashing with anyio.BrokenResourceError ([8e0648d](https://github.com/google/adk-python/commit/8e0648df23d0694afd3e245ec4a3c41aa935120a)) +* Fix Safely handle `FunctionDeclaration` without a `required` attribute ([93aad61](https://github.com/google/adk-python/commit/93aad611983dc1daf415d3a73105db45bbdd1988)) +* Fix status code in error message in RestApiTool ([9b75456](https://github.com/google/adk-python/commit/9b754564b3cc5a06ad0c6ae2cd2d83082f9f5943)) +* Fix Use `async for` to loop through event iterator to get all events in vertex_ai_session_service ([9211f4c](https://github.com/google/adk-python/commit/9211f4ce8cc6d918df314d6a2ff13da2e0ef35fa)) +* Fix: Fixes DeprecationWarning when using send method ([2882995](https://github.com/google/adk-python/commit/28829952890c39dbdb4463b2b67ff241d0e9ef6d)) +* Improve logic for checking if a MCP session is 
disconnected ([a48a1a9](https://github.com/google/adk-python/commit/a48a1a9e889d4126e6f30b56c93718dfbacef624)) +* Improve handling of partial and complete transcriptions in live calls ([1819ecb](https://github.com/google/adk-python/commit/1819ecb4b8c009d02581c2d060fae49cd7fdf653)) +* Keep vertex session event after the session update time ([0ec0195](https://github.com/google/adk-python/commit/0ec01956e86df6ae8e6553c70e410f1f8238ba88)) +* Let part converters also return multiple parts so they can support more usecases ([824ab07](https://github.com/google/adk-python/commit/824ab072124e037cc373c493f43de38f8b61b534)) +* Load agent/app before creating session ([236f562](https://github.com/google/adk-python/commit/236f562cd275f84837be46f7dfb0065f85425169)) +* Remove app name from FileArtifactService directory structure ([12db84f](https://github.com/google/adk-python/commit/12db84f5cd6d8b6e06142f6f6411f6b78ff3f177)) +* Remove hardcoded `google-cloud-aiplatform` version in agent engine requirements ([e15e19d](https://github.com/google/adk-python/commit/e15e19da05ee1b763228467e83f6f73e0eced4b5)) +* Stop updating write mode in the global settings during tool execution ([5adbf95](https://github.com/google/adk-python/commit/5adbf95a0ab0657dd7df5c4a6bac109d424d436e)) +* Update description for `load_artifacts` tool ([c485889](https://github.com/google/adk-python/commit/c4858896ff085bedcfbc42b2010af8bd78febdd0)) + +### Improvements + +* Add BigQuery related label handling ([ffbab4c](https://github.com/google/adk-python/commit/ffbab4cf4ed6ceb313241c345751214d3c0e11ce)) +* Add demo for rewind ([8eb1bdb](https://github.com/google/adk-python/commit/8eb1bdbc58dc709006988f5b6eec5fda25bd0c89)) +* Add debug logging for live connection ([5d5708b](https://github.com/google/adk-python/commit/5d5708b2ab26cb714556311c490b4d6f0a1f9666)) +* Add debug logging for missing function call events ([f3d6fcf](https://github.com/google/adk-python/commit/f3d6fcf44411d07169c14ae12189543f44f96c27)) +* Add default retry options as fall back to llm_request that are made during evals ([696852a](https://github.com/google/adk-python/commit/696852a28095a024cbe76413ee7617356e19a9e3)) +* Add plugin for returning GenAI Parts from tools into the model request ([116b26c](https://github.com/google/adk-python/commit/116b26c33e166bf1a22964e2b67013907fbfcb80)) +* Add support for abstract types in AFC ([2efc184](https://github.com/google/adk-python/commit/2efc184a46173529bdfc622db0d6f3866e7ee778)) +* Add support for structured output schemas in LiteLLM models ([7ea4aed](https://github.com/google/adk-python/commit/7ea4aed35ba70ec5a38dc1b3b0a9808183c2bab1)) +* Add tests for `max_query_result_rows` in BigQuery tool config ([fd33610](https://github.com/google/adk-python/commit/fd33610e967ad814bc02422f5d14dae046bee833)) +* Add type hints in `cleanup_unused_files.py` ([2dea573](https://github.com/google/adk-python/commit/2dea5733b759a7a07d74f36a4d6da7b081afc732)) +* Add util to build our llms.txt and llms-full.txt files +* ADK changes ([f1f4467](https://github.com/google/adk-python/commit/f1f44675e4a86b75e72cfd838efd8a0399f23e24)) +* Defer import of `google.cloud.storage` in `GCSArtifactService` ([999af55](https://github.com/google/adk-python/commit/999af5588005e7b29451bdbf9252265187ca992d)) +* Defer import of `live`, `Client` and `_transformers` in `google.genai` ([22c6dbe](https://github.com/google/adk-python/commit/22c6dbe83cd1a8900d0ac6fd23d2092f095189fa)) +* Enhance the messaging with possible fixes for RESOURCE_EXHAUSTED errors from Gemini 
([b2c45f8](https://github.com/google/adk-python/commit/b2c45f8d910eb7bca4805c567279e65aff72b58a)) +* Improve gepa tau-bench colab for external use ([e02f177](https://github.com/google/adk-python/commit/e02f177790d9772dd253c9102b80df1a9418aa7f)) +* Improve gepa voter agent demo colab ([d118479](https://github.com/google/adk-python/commit/d118479ccf3a970ce9b24ac834b4b6764edb5de4)) +* Lazy import DatabaseSessionService in the adk/sessions/ module ([5f05749](https://github.com/google/adk-python/commit/5f057498a274d3b3db0be0866f04d5225334f54a)) +* Move adk_agent_builder_assistant to built_in_agents ([b2b7f2d](https://github.com/google/adk-python/commit/b2b7f2d6aa5b919a00a92abaf2543993746e939e)) +* Plumb memory service from LocalEvalService to EvaluationGenerator ([dc3f60c](https://github.com/google/adk-python/commit/dc3f60cc939335da49399a69c0b4abc0e7f25ea4)) +* Removes the unrealistic todo comment of visibility management ([e511eb1](https://github.com/google/adk-python/commit/e511eb1f70f2a3fccc9464ddaf54d0165db22feb)) +* Returns agent state regardless if ctx.is_resumable ([d6b928b](https://github.com/google/adk-python/commit/d6b928bdf7cdbf8f1925d4c5227c7d580093348e)) +* Stop logging the full content of LLM blobs ([0826755](https://github.com/google/adk-python/commit/082675546f501a70f4bc8969b9431a2e4808bd13)) +* Update ADK web to match main branch ([14e3802](https://github.com/google/adk-python/commit/14e3802643a2d8ce436d030734fafd163080a1ad)) +* Update agent instructions and retry limit in `plugin_reflect_tool_retry` sample ([01bac62](https://github.com/google/adk-python/commit/01bac62f0c14cce5d454a389b64a9f44a03a3673)) +* Update conformance test CLI to handle long-running tool calls ([dd706bd](https://github.com/google/adk-python/commit/dd706bdc4563a2a815459482237190a63994cb6f)) +* Update Gemini Live model names in live bidi streaming sample ([aa77834](https://github.com/google/adk-python/commit/aa77834e2ecd4b77dfb4e689ef37549b3ebd6134)) + + ## [1.18.0](https://github.com/google/adk-python/compare/v1.17.0...v1.18.0) (2025-11-05) ### Features diff --git a/src/google/adk/version.py b/src/google/adk/version.py index 0a21522cb6..9478b2a547 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.18.0" +__version__ = "1.19.0" From a3aa07722a7de3e08807e86fd10f28938f0b267d Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 19 Nov 2025 17:50:12 -0800 Subject: [PATCH 04/63] fix: Update the retry_on_closed_resource decorator to retry on all errors Retrying only on closed_resource error is not enough to be reliable for production environments due to the other network errors that may occur -- remote protocol error, read timeout, etc. We will update this to retry on all errors. Since it is only a one-time retry, it should not affect latency significantly. Fixes https://github.com/google/adk-python/issues/2561. 
Co-authored-by: Kathy Wu PiperOrigin-RevId: 834514264 --- .../adk/tools/mcp_tool/mcp_session_manager.py | 16 ++++++++-------- src/google/adk/tools/mcp_tool/mcp_tool.py | 4 ++-- src/google/adk/tools/mcp_tool/mcp_toolset.py | 4 ++-- .../tools/mcp_tool/test_mcp_session_manager.py | 10 +++++----- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index d95d48f282..7d9714aada 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -108,10 +108,10 @@ class StreamableHTTPConnectionParams(BaseModel): terminate_on_close: bool = True -def retry_on_closed_resource(func): - """Decorator to automatically retry action when MCP session is closed. +def retry_on_errors(func): + """Decorator to automatically retry action when MCP session errors occur. - When MCP session was closed, the decorator will automatically retry the + When MCP session errors occur, the decorator will automatically retry the action once. The create_session method will handle creating a new session if the old one was disconnected. @@ -126,11 +126,11 @@ def retry_on_closed_resource(func): async def wrapper(self, *args, **kwargs): try: return await func(self, *args, **kwargs) - except (anyio.ClosedResourceError, anyio.BrokenResourceError): - # If the session connection is closed or unusable, we will retry the - # function to reconnect to the server. create_session will handle - # detecting and replacing disconnected sessions. - logger.info('Retrying %s due to closed resource', func.__name__) + except Exception as e: + # If an error is thrown, we will retry the function to reconnect to the + # server. create_session will handle detecting and replacing disconnected + # sessions. + logger.info('Retrying %s due to error: %s', func.__name__, e) return await func(self, *args, **kwargs) return wrapper diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index ad420a3d0f..284aea4105 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -34,7 +34,7 @@ from ...features import is_feature_enabled from .._gemini_schema_util import _to_gemini_schema from .mcp_session_manager import MCPSessionManager -from .mcp_session_manager import retry_on_closed_resource +from .mcp_session_manager import retry_on_errors # Attempt to import MCP Tool from the MCP library, and hints user to upgrade # their Python version to 3.10 if it fails. 
@@ -195,7 +195,7 @@ async def run_async( return {"error": "This tool call is rejected."} return await super().run_async(args=args, tool_context=tool_context) - @retry_on_closed_resource + @retry_on_errors @override async def _run_async_impl( self, *, args, tool_context: ToolContext, credential: AuthCredential diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index daa88f9031..3768477e1d 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -37,7 +37,7 @@ from ..tool_configs import BaseToolConfig from ..tool_configs import ToolArgsConfig from .mcp_session_manager import MCPSessionManager -from .mcp_session_manager import retry_on_closed_resource +from .mcp_session_manager import retry_on_errors from .mcp_session_manager import SseConnectionParams from .mcp_session_manager import StdioConnectionParams from .mcp_session_manager import StreamableHTTPConnectionParams @@ -155,7 +155,7 @@ def __init__( self._auth_credential = auth_credential self._require_confirmation = require_confirmation - @retry_on_closed_resource + @retry_on_errors async def get_tools( self, readonly_context: Optional[ReadonlyContext] = None, diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 6c001ccf65..b2d6b1cb88 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -32,7 +32,7 @@ # Import dependencies with version checking try: from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_closed_resource + from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams @@ -44,7 +44,7 @@ class DummyClass: pass MCPSessionManager = DummyClass - retry_on_closed_resource = lambda x: x + retry_on_errors = lambda x: x SseConnectionParams = DummyClass StdioConnectionParams = DummyClass StreamableHTTPConnectionParams = DummyClass @@ -375,12 +375,12 @@ async def test_close_with_errors(self): assert "Close error 1" in error_output -def test_retry_on_closed_resource_decorator(): - """Test the retry_on_closed_resource decorator.""" +def test_retry_on_errors_decorator(): + """Test the retry_on_errors decorator.""" call_count = 0 - @retry_on_closed_resource + @retry_on_errors async def mock_function(self): nonlocal call_count call_count += 1 From caf23ac49fe08bc7f625c61eed4635c26852c3ba Mon Sep 17 00:00:00 2001 From: Holt Skinner <13262395+holtskinner@users.noreply.github.com> Date: Thu, 20 Nov 2025 09:59:52 -0800 Subject: [PATCH 05/63] docs: Add Code Wiki badge to README Merge https://github.com/google/adk-python/pull/3603 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: #_issue_number_ - Related: #_issue_number_ **2. 
Or, if no issue exists, describe the change:** _If applicable, please follow the issue templates to provide as much detail as possible._ **Problem:** _A clear and concise description of what the problem is._ **Solution:** _A clear and concise description of what you want to happen and why you choose this solution._ ### Testing Plan _Please describe the tests that you ran to verify your changes. This is required for all PRs that are not small documentation or typo fixes._ **Unit Tests:** - [ ] I have added or updated unit tests for my change. - [ ] All unit tests pass locally. _Please include a summary of passed `pytest` results._ **Manual End-to-End (E2E) Tests:** _Please provide instructions on how to manually test your changes, including any necessary setup or configuration. Please provide logs or screenshots to help reviewers better understand the fix._ ### Checklist - [ ] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [ ] I have performed a self-review of my own code. - [ ] I have commented my code, particularly in hard-to-understand areas. - [ ] I have added tests that prove my fix is effective or that my feature works. - [ ] New and existing unit tests pass locally with my changes. - [ ] I have manually tested my changes end-to-end. - [ ] Any dependent changes have been merged and published in downstream modules. ### Additional context _Add any other context or screenshots about the feature request here._ Co-authored-by: Hangfei Lin COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3603 from holtskinner:patch-1 833405898b747438e3f8a1b8c34095e25135fca3 PiperOrigin-RevId: 834805195 --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 3113b1ccdf..9046cc49a3 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![PyPI](https://img.shields.io/pypi/v/google-adk)](https://pypi.org/project/google-adk/) [![Python Unit Tests](https://github.com/google/adk-python/actions/workflows/python-unit-tests.yml/badge.svg)](https://github.com/google/adk-python/actions/workflows/python-unit-tests.yml) [![r/agentdevelopmentkit](https://img.shields.io/badge/Reddit-r%2Fagentdevelopmentkit-FF4500?style=flat&logo=reddit&logoColor=white)](https://www.reddit.com/r/agentdevelopmentkit/) -[![Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/google/adk-python) +Ask Code Wiki
From bf8b85da52ac6903acfc589e478cf698d43443c1 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 20 Nov 2025 11:41:11 -0800 Subject: [PATCH 06/63] fix: save sessions with camelCase aliases Make sure that the adk run --save_session writes session JSON using the Pydantic camelCase aliases (by_alias=True), matching ADK Web and keeping session files consistent Close #3558 Co-authored-by: George Weale PiperOrigin-RevId: 834847209 --- src/google/adk/cli/cli.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index ed294d3922..5ae18aac0a 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -218,6 +218,8 @@ async def run_cli( session_id=session.id, ) with open(session_path, 'w', encoding='utf-8') as f: - f.write(session.model_dump_json(indent=2, exclude_none=True)) + f.write( + session.model_dump_json(indent=2, exclude_none=True, by_alias=True) + ) print('Session saved to', session_path) From 084c2de0dac84697906e2b4beebf008bbd9ae8e1 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 20 Nov 2025 11:48:12 -0800 Subject: [PATCH 07/63] fix: Make sure request bodies without explicit names are named 'body' The `Parameter` class now provides default Python names based on the parameter location when the original name is empty. This prevents parameters from having an empty string as their Python name, especially for request bodies defined without a top-level name. Close #2213 Co-authored-by: George Weale PiperOrigin-RevId: 834850255 --- .../adk/tools/openapi_tool/common/common.py | 18 +++++++--- .../openapi_spec_parser/operation_parser.py | 13 +++++-- .../tools/openapi_tool/common/test_common.py | 18 ++++++++++ .../test_operation_parser.py | 34 +++++++++++++++++++ 4 files changed, 76 insertions(+), 7 deletions(-) diff --git a/src/google/adk/tools/openapi_tool/common/common.py b/src/google/adk/tools/openapi_tool/common/common.py index 7187b1bd1b..1df3125e3d 100644 --- a/src/google/adk/tools/openapi_tool/common/common.py +++ b/src/google/adk/tools/openapi_tool/common/common.py @@ -64,11 +64,9 @@ class ApiParameter(BaseModel): required: bool = False def model_post_init(self, _: Any): - self.py_name = ( - self.py_name - if self.py_name - else rename_python_keywords(_to_snake_case(self.original_name)) - ) + if not self.py_name: + inferred_name = rename_python_keywords(_to_snake_case(self.original_name)) + self.py_name = inferred_name or self._default_py_name() if isinstance(self.param_schema, str): self.param_schema = Schema.model_validate_json(self.param_schema) @@ -77,6 +75,16 @@ def model_post_init(self, _: Any): self.type_hint = TypeHintHelper.get_type_hint(self.param_schema) return self + def _default_py_name(self) -> str: + location_defaults = { + 'body': 'body', + 'query': 'query_param', + 'path': 'path_param', + 'header': 'header_param', + 'cookie': 'cookie_param', + } + return location_defaults.get(self.param_location or '', 'value') + @model_serializer def _serialize(self): return { diff --git a/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py b/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py index 06d692a2b0..326ff6787e 100644 --- a/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +++ b/src/google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py @@ -139,10 +139,19 @@ def _process_request_body(self): ) ) else: + # Prefer explicit body name to avoid empty keys when schema lacks type + # information (e.g., 
oneOf/anyOf/allOf) while retaining legacy behavior + # for simple scalar types. + if schema and (schema.oneOf or schema.anyOf or schema.allOf): + param_name = 'body' + elif not schema or not schema.type: + param_name = 'body' + else: + param_name = '' + self._params.append( - # Empty name for unnamed body param ApiParameter( - original_name='', + original_name=param_name, param_location='body', param_schema=schema, description=description, diff --git a/tests/unittests/tools/openapi_tool/common/test_common.py b/tests/unittests/tools/openapi_tool/common/test_common.py index 47aeb79fdb..faece5be89 100644 --- a/tests/unittests/tools/openapi_tool/common/test_common.py +++ b/tests/unittests/tools/openapi_tool/common/test_common.py @@ -74,6 +74,24 @@ def test_api_parameter_keyword_rename(self): ) assert param.py_name == 'param_in' + def test_api_parameter_uses_location_default_when_name_missing(self): + schema = Schema(type='string') + param = ApiParameter( + original_name='', + param_location='body', + param_schema=schema, + ) + assert param.py_name == 'body' + + def test_api_parameter_uses_value_default_when_location_unknown(self): + schema = Schema(type='integer') + param = ApiParameter( + original_name='', + param_location='', + param_schema=schema, + ) + assert param.py_name == 'value' + def test_api_parameter_custom_py_name(self): schema = Schema(type='integer') param = ApiParameter( diff --git a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py index 26cb944a22..83741c97a2 100644 --- a/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py +++ b/tests/unittests/tools/openapi_tool/openapi_spec_parser/test_operation_parser.py @@ -164,6 +164,40 @@ def test_process_request_body_no_name(): assert parser._params[0].param_location == 'body' +def test_process_request_body_one_of_schema_assigns_name(): + """Ensures oneOf bodies result in a named parameter.""" + operation = Operation( + operationId='one_of_request', + requestBody=RequestBody( + content={ + 'application/json': MediaType( + schema=Schema( + oneOf=[ + Schema( + type='object', + properties={ + 'type': Schema(type='string'), + 'stage': Schema(type='string'), + }, + ) + ], + discriminator={'propertyName': 'type'}, + ) + ) + } + ), + responses={'200': Response(description='ok')}, + ) + parser = OperationParser(operation) + params = parser.get_parameters() + assert len(params) == 1 + assert params[0].original_name == 'body' + assert params[0].py_name == 'body' + schema = parser.get_json_schema() + assert 'body' in schema['properties'] + assert '' not in schema['properties'] + + def test_process_request_body_empty_object(): """Test _process_request_body with a schema that is of type object but with no properties.""" operation = Operation( From cd54f48fed0c87b54fb19743c9c75e790c5d9135 Mon Sep 17 00:00:00 2001 From: Issac Date: Thu, 20 Nov 2025 12:26:11 -0800 Subject: [PATCH 08/63] fix: fix paths for public docs Merge https://github.com/google/adk-python/pull/3572 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** N/A **2. Or, if no issue exists, describe the change:** **Problem:** Docs fix ### Checklist - [ ] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. 
- [ ] I have performed a self-review of my own code. - [ ] I have commented my code, particularly in hard-to-understand areas. - [ ] I have added tests that prove my fix is effective or that my feature works. - [ ] New and existing unit tests pass locally with my changes. - [ ] I have manually tested my changes end-to-end. - [ ] Any dependent changes have been merged and published in downstream modules. Co-authored-by: Hangfei Lin COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3572 from issacg:patch-1 b7c7ed46ff0fb018f4da1537535eff27c323daf5 PiperOrigin-RevId: 834864431 --- contributing/samples/computer_use/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/contributing/samples/computer_use/README.md b/contributing/samples/computer_use/README.md index ff7ae6c0a5..38b6fe79c6 100644 --- a/contributing/samples/computer_use/README.md +++ b/contributing/samples/computer_use/README.md @@ -19,7 +19,7 @@ The computer use agent consists of: Install the required Python packages from the requirements file: ```bash -uv pip install -r internal/samples/computer_use/requirements.txt +uv pip install -r contributing/samples/computer_use/requirements.txt ``` ### 2. Install Playwright Dependencies @@ -45,7 +45,7 @@ playwright install chromium To start the computer use agent, run the following command from the project root: ```bash -adk web internal/samples +adk web contributing/samples ``` This will start the ADK web interface where you can interact with the computer_use agent. From b57fe5f4598925ec7592917bb32c7f0d6eca287a Mon Sep 17 00:00:00 2001 From: Dylan <114695692+dylan-apex@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:14:09 -0800 Subject: [PATCH 09/63] feat(web): add list-apps-detailed endpoint Merge https://github.com/google/adk-python/pull/3430 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: #3429 **2. Or, if no issue exists, describe the change:** _If applicable, please follow the issue templates to provide as much detail as possible._ **Problem:** The existing `/list-apps` endpoint only returns the name of the folder that each agent is in **Solution:** This adds a new endpoint `/list-apps-detailed` which will load each agent using the existing `AgentLoader.load_agent` method, and then return the folder name, display name (with underscores replaced with spaces for a more readable version), description, and the agent type. This does introduce overhead if you had multiple agents since they all need to be loaded, but by maintaining the existing `/list-apps` endpoint, users can choose which one to hit if they don't want to load all agents. Since the existing `load_agents` method will cache results, there's only a penalty on the first hit. ### Testing Plan Created a unit test for this, similar to the `/list-apps`. Also tested this with my own ADK instance to verify it loaded correctly. 
``` curl --location "localhost:8000/list-apps-detailed" ``` ```json { "apps": [ { "name": "agent_1", "displayName": "Agent 1", "description": "A test description for a test agent", "agentType": "package" }, { "name": "agent_2", "displayName": "Agent 2", "description": "A test description for a test agent ", "agentType": "package" }, { "name": "agent_3", "displayName": "Agent 3", "description": "A test description for a test agent", "agentType": "package" } ] } ``` **Unit Tests:** - [X] I have added or updated unit tests for my change. - [X] All unit tests pass locally. 3054 passed, 2383 warnings in 46.96s ### Checklist - [X] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [X] I have performed a self-review of my own code. - [X] I have commented my code, particularly in hard-to-understand areas. - [X] I have added tests that prove my fix is effective or that my feature works. - [X] New and existing unit tests pass locally with my changes. - [X] I have manually tested my changes end-to-end. - [X] Any dependent changes have been merged and published in downstream modules. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3430 from dylan-apex:more-detailed-list-apps e6864fd61a673da5fd2fb28d2d7d72cb90f5af0a PiperOrigin-RevId: 834907771 --- src/google/adk/cli/adk_web_server.py | 20 +++++++- src/google/adk/cli/utils/agent_loader.py | 46 +++++++++++++++++++ src/google/adk/cli/utils/base_agent_loader.py | 13 ++++++ tests/unittests/cli/test_fast_api.py | 28 +++++++++++ 4 files changed, 106 insertions(+), 1 deletion(-) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 1b422fe335..45747a52a1 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -280,6 +280,17 @@ class ListMetricsInfoResponse(common.BaseModel): metrics_info: list[MetricInfo] +class AppInfo(common.BaseModel): + name: str + root_agent_name: str + description: str + language: Literal["yaml", "python"] + + +class ListAppsResponse(common.BaseModel): + apps: list[AppInfo] + + def _setup_telemetry( otel_to_cloud: bool = False, internal_exporters: Optional[list[SpanProcessor]] = None, @@ -699,7 +710,14 @@ async def internal_lifespan(app: FastAPI): ) @app.get("/list-apps") - async def list_apps() -> list[str]: + async def list_apps( + detailed: bool = Query( + default=False, description="Return detailed app information" + ) + ) -> list[str] | ListAppsResponse: + if detailed: + apps_info = self.agent_loader.list_agents_detailed() + return ListAppsResponse(apps=[AppInfo(**app) for app in apps_info]) return self.agent_loader.list_agents() @app.get("/debug/trace/{event_id}", tags=[TAG_DEBUG]) diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py index 9f01705d4f..9661df6fe8 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/agent_loader.py @@ -20,6 +20,8 @@ import os from pathlib import Path import sys +from typing import Any +from typing import Literal from typing import Optional from typing import Union @@ -341,6 +343,50 @@ def list_agents(self) -> list[str]: agent_names.sort() return agent_names + def list_agents_detailed(self) -> list[dict[str, Any]]: + """Lists all agents with detailed metadata (name, description, type).""" + agent_names = self.list_agents() + apps_info = [] + + for agent_name in agent_names: + try: + loaded = self.load_agent(agent_name) + if isinstance(loaded, App): + agent = 
loaded.root_agent + else: + agent = loaded + + language = self._determine_agent_language(agent_name) + + app_info = { + "name": agent_name, + "root_agent_name": agent.name, + "description": agent.description, + "language": language, + } + apps_info.append(app_info) + + except Exception as e: + logger.error("Failed to load agent '%s': %s", agent_name, e) + continue + + return apps_info + + def _determine_agent_language( + self, agent_name: str + ) -> Literal["yaml", "python"]: + """Determine the type of agent based on file structure.""" + base_path = Path.cwd() / self.agents_dir / agent_name + + if (base_path / "root_agent.yaml").exists(): + return "yaml" + elif (base_path / "agent.py").exists(): + return "python" + elif (base_path / "__init__.py").exists(): + return "python" + + raise ValueError(f"Could not determine agent type for '{agent_name}'.") + def remove_agent_from_cache(self, agent_name: str): # Clear module cache for the agent and its submodules keys_to_delete = [ diff --git a/src/google/adk/cli/utils/base_agent_loader.py b/src/google/adk/cli/utils/base_agent_loader.py index d62a6b8651..bcef0dae42 100644 --- a/src/google/adk/cli/utils/base_agent_loader.py +++ b/src/google/adk/cli/utils/base_agent_loader.py @@ -18,6 +18,7 @@ from abc import ABC from abc import abstractmethod +from typing import Any from typing import Union from ...agents.base_agent import BaseAgent @@ -34,3 +35,15 @@ def load_agent(self, agent_name: str) -> Union[BaseAgent, App]: @abstractmethod def list_agents(self) -> list[str]: """Lists all agents available in the agent loader in alphabetical order.""" + + def list_agents_detailed(self) -> list[dict[str, Any]]: + agent_names = self.list_agents() + return [ + { + 'name': name, + 'display_name': None, + 'description': None, + 'type': None, + } + for name in agent_names + ] diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 2d7b9472ba..d50bfcd8e5 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -190,6 +190,14 @@ def load_agent(self, app_name): def list_agents(self): return ["test_app"] + def list_agents_detailed(self): + return [{ + "name": "test_app", + "root_agent_name": "test_agent", + "description": "A test agent for unit testing", + "language": "python", + }] + return MockAgentLoader(".") @@ -548,6 +556,26 @@ def test_list_apps(test_app): logger.info(f"Listed apps: {data}") +def test_list_apps_detailed(test_app): + """Test listing available applications with detailed metadata.""" + response = test_app.get("/list-apps?detailed=true") + + assert response.status_code == 200 + data = response.json() + assert isinstance(data, dict) + assert "apps" in data + assert isinstance(data["apps"], list) + + for app in data["apps"]: + assert "name" in app + assert "rootAgentName" in app + assert "description" in app + assert "language" in app + assert app["language"] in ["yaml", "python"] + + logger.info(f"Listed apps: {data}") + + def test_create_session_with_id(test_app, test_session_info): """Test creating a session with a specific ID.""" new_session_id = "new_session_id" From 31cfa3b82bff2a130622d3ba0909024927121ce4 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 20 Nov 2025 16:29:49 -0800 Subject: [PATCH 10/63] feat: Capture thinking output, forward raw payloads, and fix exec locals MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LlmResponse/Event now keep both provider reasoning output and the raw vendor payload so callbacks and 
loggers can inspect hidden “thoughts” or trace bugs without rewriting adapters. LiteLLM’s adapter and streaming loop emit reasoning chunks alongside text and aggregate them into final events -> all responses now carry a JSON-safe copy of the source payload for debug. UnsafeLocalCodeExecutor uses the documented exec(code, globals, globals) form, letting helper functions defined inside snippets call each other. Close #1749 Co-authored-by: George Weale PiperOrigin-RevId: 834956847 --- .../unsafe_local_code_executor.py | 2 +- src/google/adk/models/lite_llm.py | 132 ++++++++++++++++-- .../test_unsafe_local_code_executor.py | 20 +++ tests/unittests/models/test_litellm.py | 86 +++++------- 4 files changed, 175 insertions(+), 65 deletions(-) diff --git a/src/google/adk/code_executors/unsafe_local_code_executor.py b/src/google/adk/code_executors/unsafe_local_code_executor.py index 6dd2ae9d8c..b47fbd17e9 100644 --- a/src/google/adk/code_executors/unsafe_local_code_executor.py +++ b/src/google/adk/code_executors/unsafe_local_code_executor.py @@ -72,7 +72,7 @@ def execute_code( _prepare_globals(code_execution_input.code, globals_) stdout = io.StringIO() with redirect_stdout(stdout): - exec(code_execution_input.code, globals_) + exec(code_execution_input.code, globals_, globals_) output = stdout.getvalue() except Exception as e: error = str(e) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index c263a41b2a..e83e7efdb4 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -96,6 +96,64 @@ def _decode_inline_text_data(raw_bytes: bytes) -> str: return raw_bytes.decode("latin-1", errors="replace") +def _iter_reasoning_texts(reasoning_value: Any) -> Iterable[str]: + """Yields textual fragments from provider specific reasoning payloads.""" + if reasoning_value is None: + return + + if isinstance(reasoning_value, types.Content): + if not reasoning_value.parts: + return + for part in reasoning_value.parts: + if part and part.text: + yield part.text + return + + if isinstance(reasoning_value, str): + yield reasoning_value + return + + if isinstance(reasoning_value, list): + for value in reasoning_value: + yield from _iter_reasoning_texts(value) + return + + if isinstance(reasoning_value, dict): + # LiteLLM currently nests “reasoning” text under a few known keys. 
+ # (Documented in https://docs.litellm.ai/docs/openai#reasoning-outputs) + for key in ("text", "content", "reasoning", "reasoning_content"): + text_value = reasoning_value.get(key) + if isinstance(text_value, str): + yield text_value + return + + text_attr = getattr(reasoning_value, "text", None) + if isinstance(text_attr, str): + yield text_attr + elif isinstance(reasoning_value, (int, float, bool)): + yield str(reasoning_value) + + +def _convert_reasoning_value_to_parts(reasoning_value: Any) -> List[types.Part]: + """Converts provider reasoning payloads into Gemini thought parts.""" + return [ + types.Part(text=text, thought=True) + for text in _iter_reasoning_texts(reasoning_value) + if text + ] + + +def _extract_reasoning_value(message: Message | Dict[str, Any]) -> Any: + """Fetches the reasoning payload from a LiteLLM message or dict.""" + if message is None: + return None + if hasattr(message, "reasoning_content"): + return getattr(message, "reasoning_content") + if isinstance(message, dict): + return message.get("reasoning_content") + return None + + class ChatCompletionFileUrlObject(TypedDict, total=False): file_data: str file_id: str @@ -113,6 +171,10 @@ class TextChunk(BaseModel): text: str +class ReasoningChunk(BaseModel): + parts: List[types.Part] + + class UsageMetadataChunk(BaseModel): prompt_tokens: int completion_tokens: int @@ -660,7 +722,6 @@ def _function_declaration_to_tool_param( }, } - # Handle required field from parameters required_fields = ( getattr(function_declaration.parameters, "required", None) if function_declaration.parameters @@ -668,8 +729,6 @@ def _function_declaration_to_tool_param( ) if required_fields: tool_params["function"]["parameters"]["required"] = required_fields - # parameters_json_schema already has required field in the json schema, - # no need to add it separately return tool_params @@ -678,7 +737,14 @@ def _model_response_to_chunk( response: ModelResponse, ) -> Generator[ Tuple[ - Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]], + Optional[ + Union[ + TextChunk, + FunctionChunk, + UsageMetadataChunk, + ReasoningChunk, + ] + ], Optional[str], ], None, @@ -703,11 +769,18 @@ def _model_response_to_chunk( message_content: Optional[OpenAIMessageContent] = None tool_calls: list[ChatCompletionMessageToolCall] = [] + reasoning_parts: List[types.Part] = [] if message is not None: ( message_content, tool_calls, ) = _split_message_content_and_tool_calls(message) + reasoning_value = _extract_reasoning_value(message) + if reasoning_value: + reasoning_parts = _convert_reasoning_value_to_parts(reasoning_value) + + if reasoning_parts: + yield ReasoningChunk(parts=reasoning_parts), finish_reason if message_content: yield TextChunk(text=message_content), finish_reason @@ -771,8 +844,13 @@ def _model_response_to_generate_content_response( if not message: raise ValueError("No message in response") + thought_parts = _convert_reasoning_value_to_parts( + _extract_reasoning_value(message) + ) llm_response = _message_to_generate_content_response( - message, model_version=response.model + message, + model_version=response.model, + thought_parts=thought_parts or None, ) if finish_reason: # If LiteLLM already provides a FinishReason enum (e.g., for Gemini), use @@ -797,7 +875,11 @@ def _model_response_to_generate_content_response( def _message_to_generate_content_response( - message: Message, *, is_partial: bool = False, model_version: str = None + message: Message, + *, + is_partial: bool = False, + model_version: str = None, + thought_parts: 
Optional[List[types.Part]] = None, ) -> LlmResponse: """Converts a litellm message to LlmResponse. @@ -810,7 +892,13 @@ def _message_to_generate_content_response( The LlmResponse. """ - parts = [] + parts: List[types.Part] = [] + if not thought_parts: + thought_parts = _convert_reasoning_value_to_parts( + _extract_reasoning_value(message) + ) + if thought_parts: + parts.extend(thought_parts) message_content, tool_calls = _split_message_content_and_tool_calls(message) if isinstance(message_content, str) and message_content: parts.append(types.Part.from_text(text=message_content)) @@ -972,15 +1060,9 @@ def _build_function_declaration_log( k: v.model_dump(exclude_none=True) for k, v in func_decl.parameters.properties.items() }) - elif func_decl.parameters_json_schema: - param_str = str(func_decl.parameters_json_schema) - return_str = "None" if func_decl.response: return_str = str(func_decl.response.model_dump(exclude_none=True)) - elif func_decl.response_json_schema: - return_str = str(func_decl.response_json_schema) - return f"{func_decl.name}: {param_str} -> {return_str}" @@ -1182,6 +1264,7 @@ async def generate_content_async( if stream: text = "" + reasoning_parts: List[types.Part] = [] # Track function calls by index function_calls = {} # index -> {name, args, id} completion_args["stream"] = True @@ -1223,6 +1306,14 @@ async def generate_content_async( is_partial=True, model_version=part.model, ) + elif isinstance(chunk, ReasoningChunk): + if chunk.parts: + reasoning_parts.extend(chunk.parts) + yield LlmResponse( + content=types.Content(role="model", parts=list(chunk.parts)), + partial=True, + model_version=part.model, + ) elif isinstance(chunk, UsageMetadataChunk): usage_metadata = types.GenerateContentResponseUsageMetadata( prompt_token_count=chunk.prompt_tokens, @@ -1256,16 +1347,27 @@ async def generate_content_async( tool_calls=tool_calls, ), model_version=part.model, + thought_parts=list(reasoning_parts) + if reasoning_parts + else None, ) ) text = "" + reasoning_parts = [] function_calls.clear() - elif finish_reason == "stop" and text: + elif finish_reason == "stop" and (text or reasoning_parts): + message_content = text if text else None aggregated_llm_response = _message_to_generate_content_response( - ChatCompletionAssistantMessage(role="assistant", content=text), + ChatCompletionAssistantMessage( + role="assistant", content=message_content + ), model_version=part.model, + thought_parts=list(reasoning_parts) + if reasoning_parts + else None, ) text = "" + reasoning_parts = [] # waiting until streaming ends to yield the llm_response as litellm tends # to send chunk that contains usage_metadata after the chunk with diff --git a/tests/unittests/code_executors/test_unsafe_local_code_executor.py b/tests/unittests/code_executors/test_unsafe_local_code_executor.py index eeb10b34fa..e5d5c4f792 100644 --- a/tests/unittests/code_executors/test_unsafe_local_code_executor.py +++ b/tests/unittests/code_executors/test_unsafe_local_code_executor.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import textwrap from unittest.mock import MagicMock from google.adk.agents.base_agent import BaseAgent @@ -101,3 +102,22 @@ def test_execute_code_empty(self, mock_invocation_context: InvocationContext): result = executor.execute_code(mock_invocation_context, code_input) assert result.stdout == "" assert result.stderr == "" + + def test_execute_code_nested_function_call( + self, mock_invocation_context: InvocationContext + ): + executor = UnsafeLocalCodeExecutor() + code_input = CodeExecutionInput(code=(textwrap.dedent(""" + def helper(name): + return f'hi {name}' + + def run(): + print(helper('ada')) + + run() + """))) + + result = executor.execute_code(mock_invocation_context, code_input) + + assert result.stderr == "" + assert result.stdout == "hi ada\n" diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index 8f2ae50b42..fd3983fb76 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -17,7 +17,6 @@ from unittest.mock import Mock import warnings -from google.adk.models.lite_llm import _build_function_declaration_log from google.adk.models.lite_llm import _content_to_message_param from google.adk.models.lite_llm import _FINISH_REASON_MAPPING from google.adk.models.lite_llm import _function_declaration_to_tool_param @@ -25,6 +24,7 @@ from google.adk.models.lite_llm import _get_content from google.adk.models.lite_llm import _message_to_generate_content_response from google.adk.models.lite_llm import _model_response_to_chunk +from google.adk.models.lite_llm import _model_response_to_generate_content_response from google.adk.models.lite_llm import _parse_tool_calls_from_text from google.adk.models.lite_llm import _split_message_content_and_tool_calls from google.adk.models.lite_llm import _to_litellm_response_format @@ -630,54 +630,6 @@ def completion(self, model, messages, tools, stream, **kwargs): ) -def test_build_function_declaration_log(): - """Test that _build_function_declaration_log formats function declarations correctly.""" - # Test case 1: Function with parameters and response - func_decl1 = types.FunctionDeclaration( - name="test_func1", - description="Test function 1", - parameters=types.Schema( - type=types.Type.OBJECT, - properties={ - "param1": types.Schema( - type=types.Type.STRING, description="param1 desc" - ) - }, - ), - response=types.Schema(type=types.Type.BOOLEAN, description="return bool"), - ) - log1 = _build_function_declaration_log(func_decl1) - assert log1 == ( - "test_func1: {'param1': {'description': 'param1 desc', 'type':" - " }} -> {'description': 'return bool', 'type':" - " }" - ) - - # Test case 2: Function with JSON schema parameters and response - func_decl2 = types.FunctionDeclaration( - name="test_func2", - description="Test function 2", - parameters_json_schema={ - "type": "object", - "properties": {"param2": {"type": "integer"}}, - }, - response_json_schema={"type": "string"}, - ) - log2 = _build_function_declaration_log(func_decl2) - assert log2 == ( - "test_func2: {'type': 'object', 'properties': {'param2': {'type':" - " 'integer'}}} -> {'type': 'string'}" - ) - - # Test case 3: Function with no parameters and no response - func_decl3 = types.FunctionDeclaration( - name="test_func3", - description="Test function 3", - ) - log3 = _build_function_declaration_log(func_decl3) - assert log3 == "test_func3: {} -> None" - - @pytest.mark.asyncio async def test_generate_content_async(mock_acompletion, lite_llm_instance): @@ -1535,6 +1487,42 @@ def 
test_message_to_generate_content_response_with_model(): assert response.model_version == "gemini-2.5-pro" +def test_message_to_generate_content_response_reasoning_content(): + message = { + "role": "assistant", + "content": "Visible text", + "reasoning_content": "Hidden chain", + } + response = _message_to_generate_content_response(message) + + assert len(response.content.parts) == 2 + thought_part = response.content.parts[0] + text_part = response.content.parts[1] + assert thought_part.text == "Hidden chain" + assert thought_part.thought is True + assert text_part.text == "Visible text" + + +def test_model_response_to_generate_content_response_reasoning_content(): + model_response = ModelResponse( + model="thinking-model", + choices=[{ + "message": { + "role": "assistant", + "content": "Answer", + "reasoning_content": "Step-by-step", + }, + "finish_reason": "stop", + }], + ) + + response = _model_response_to_generate_content_response(model_response) + + assert response.content.parts[0].text == "Step-by-step" + assert response.content.parts[0].thought is True + assert response.content.parts[1].text == "Answer" + + def test_parse_tool_calls_from_text_multiple_calls(): text = ( '{"name":"alpha","arguments":{"value":1}}\n' From 848fdbef7c73036c6900bf4808b892762bb0c074 Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 20 Nov 2025 17:02:10 -0800 Subject: [PATCH 11/63] fix: Filter out None values from enum lists in schema generation Close #3552 Co-authored-by: George Weale PiperOrigin-RevId: 834967220 --- src/google/adk/models/lite_llm.py | 44 ++++++++++++-------------- tests/unittests/models/test_litellm.py | 23 ++++++++++++++ 2 files changed, 43 insertions(+), 24 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index e83e7efdb4..c9712974ea 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -632,10 +632,8 @@ def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]: } -def _schema_to_dict(schema: types.Schema) -> dict: - """Recursively converts a types.Schema to a pure-python dict - - with all enum values written as lower-case strings. +def _schema_to_dict(schema: types.Schema | dict[str, Any]) -> dict: + """Recursively converts a schema object or dict to a pure-python dict. Args: schema: The schema to convert. @@ -643,38 +641,36 @@ def _schema_to_dict(schema: types.Schema) -> dict: Returns: The dictionary representation of the schema. 
""" - # Dump without json encoding so we still get Enum members - schema_dict = schema.model_dump(exclude_none=True) + schema_dict = ( + schema.model_dump(exclude_none=True) + if isinstance(schema, types.Schema) + else dict(schema) + ) + enum_values = schema_dict.get("enum") + if isinstance(enum_values, (list, tuple)): + schema_dict["enum"] = [value for value in enum_values if value is not None] - # ---- normalise this level ------------------------------------------------ - if "type" in schema_dict: - # schema_dict["type"] can be an Enum or a str + if "type" in schema_dict and schema_dict["type"] is not None: t = schema_dict["type"] - schema_dict["type"] = (t.value if isinstance(t, types.Type) else t).lower() + schema_dict["type"] = ( + t.value if isinstance(t, types.Type) else str(t) + ).lower() - # ---- recurse into `items` ----------------------------------------------- if "items" in schema_dict: - schema_dict["items"] = _schema_to_dict( - schema.items - if isinstance(schema.items, types.Schema) - else types.Schema.model_validate(schema_dict["items"]) + items = schema_dict["items"] + schema_dict["items"] = ( + _schema_to_dict(items) + if isinstance(items, (types.Schema, dict)) + else items ) - # ---- recurse into `properties` ------------------------------------------ if "properties" in schema_dict: new_props = {} for key, value in schema_dict["properties"].items(): - # value is a dict → rebuild a Schema object and recurse - if isinstance(value, dict): - new_props[key] = _schema_to_dict(types.Schema.model_validate(value)) - # value is already a Schema instance - elif isinstance(value, types.Schema): + if isinstance(value, (types.Schema, dict)): new_props[key] = _schema_to_dict(value) - # plain dict without nested schemas else: new_props[key] = value - if "type" in new_props[key]: - new_props[key]["type"] = new_props[key]["type"].lower() schema_dict["properties"] = new_props return schema_dict diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index fd3983fb76..c486806d37 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -26,6 +26,7 @@ from google.adk.models.lite_llm import _model_response_to_chunk from google.adk.models.lite_llm import _model_response_to_generate_content_response from google.adk.models.lite_llm import _parse_tool_calls_from_text +from google.adk.models.lite_llm import _schema_to_dict from google.adk.models.lite_llm import _split_message_content_and_tool_calls from google.adk.models.lite_llm import _to_litellm_response_format from google.adk.models.lite_llm import _to_litellm_role @@ -286,6 +287,28 @@ def test_to_litellm_response_format_handles_genai_schema_instance(): ) +def test_schema_to_dict_filters_none_enum_values(): + # Use model_construct to bypass strict enum validation. 
+ top_level_schema = types.Schema.model_construct( + type=types.Type.STRING, + enum=["ACTIVE", None, "INACTIVE"], + ) + nested_schema = types.Schema.model_construct( + type=types.Type.OBJECT, + properties={ + "status": types.Schema.model_construct( + type=types.Type.STRING, enum=["READY", None, "DONE"] + ), + }, + ) + + assert _schema_to_dict(top_level_schema)["enum"] == ["ACTIVE", "INACTIVE"] + assert _schema_to_dict(nested_schema)["properties"]["status"]["enum"] == [ + "READY", + "DONE", + ] + + MULTIPLE_FUNCTION_CALLS_STREAM = [ ModelResponse( choices=[ From a3e4ad3cd130714affcaa880f696aeb498cd93af Mon Sep 17 00:00:00 2001 From: George Weale Date: Thu, 20 Nov 2025 18:26:31 -0800 Subject: [PATCH 12/63] fix: Update session last update time when appending events The `database_session_service` now updates the `update_time` of a session to the event's timestamp when an event is appended Close #2721 Co-authored-by: George Weale PiperOrigin-RevId: 834994070 --- .../adk/sessions/database_session_service.py | 5 ++ .../adk/sessions/sqlite_session_service.py | 23 ++++-- .../sessions/test_session_service.py | 73 +++++++++++++++++++ 3 files changed, 94 insertions(+), 7 deletions(-) diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index b929f23409..a352918211 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -722,6 +722,11 @@ async def append_event(self, session: Session, event: Event) -> Event: if session_state_delta: storage_session.state = storage_session.state | session_state_delta + if storage_session._dialect_name == "sqlite": + update_time = datetime.utcfromtimestamp(event.timestamp) + else: + update_time = datetime.fromtimestamp(event.timestamp) + storage_session.update_time = update_time sql_session.add(StorageEvent.from_event(session, event)) await sql_session.commit() diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 10d05f6dfd..8ba6531f52 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -323,7 +323,7 @@ async def append_event(self, session: Session, event: Event) -> Event: # Trim temp state before persisting event = self._trim_temp_delta_state(event) - now = time.time() + event_timestamp = event.timestamp async with self._get_db_connection() as db: # Check for stale session @@ -355,11 +355,15 @@ async def append_event(self, session: Session, event: Event) -> Event: if app_state_delta: await self._upsert_app_state( - db, session.app_name, app_state_delta, now + db, session.app_name, app_state_delta, event_timestamp ) if user_state_delta: await self._upsert_user_state( - db, session.app_name, session.user_id, user_state_delta, now + db, + session.app_name, + session.user_id, + user_state_delta, + event_timestamp, ) if session_state_delta: await self._update_session_state_in_db( @@ -368,7 +372,7 @@ async def append_event(self, session: Session, event: Event) -> Event: session.user_id, session.id, session_state_delta, - now, + event_timestamp, ) has_session_state_delta = True @@ -392,12 +396,17 @@ async def append_event(self, session: Session, event: Event) -> Event: await db.execute( "UPDATE sessions SET update_time=? WHERE app_name=? AND user_id=?" 
" AND id=?", - (now, session.app_name, session.user_id, session.id), + ( + event_timestamp, + session.app_name, + session.user_id, + session.id, + ), ) await db.commit() - # Update timestamp with commit time - session.last_update_time = now + # Update timestamp based on event time + session.last_update_time = event_timestamp # Also update the in-memory session await super().append_event(session=session, event=event) diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 661e6ead59..45aa3feede 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -531,6 +531,79 @@ async def test_append_event_complete(service_type, tmp_path): ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', + [ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ], +) +async def test_session_last_update_time_updates_on_event( + service_type, tmp_path +): + session_service = get_session_service(service_type, tmp_path) + app_name = 'my_app' + user_id = 'user' + + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + original_update_time = session.last_update_time + + event_timestamp = original_update_time + 10 + event = Event( + invocation_id='invocation', + author='user', + timestamp=event_timestamp, + ) + await session_service.append_event(session=session, event=event) + + assert session.last_update_time == pytest.approx(event_timestamp, abs=1e-6) + + refreshed_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert refreshed_session is not None + assert refreshed_session.last_update_time == pytest.approx( + event_timestamp, abs=1e-6 + ) + assert refreshed_session.last_update_time > original_update_time + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'service_type', + [ + SessionServiceType.IN_MEMORY, + SessionServiceType.DATABASE, + SessionServiceType.SQLITE, + ], +) +async def test_get_session_with_config(service_type): + session_service = get_session_service(service_type) + app_name = 'my_app' + user_id = 'user' + + session = await session_service.create_session( + app_name=app_name, user_id=user_id + ) + original_update_time = session.last_update_time + + event = Event(invocation_id='invocation', author='user') + await session_service.append_event(session=session, event=event) + + assert session.last_update_time >= event.timestamp + + refreshed_session = await session_service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert refreshed_session is not None + assert refreshed_session.last_update_time >= event.timestamp + assert refreshed_session.last_update_time > original_update_time + + @pytest.mark.asyncio @pytest.mark.parametrize( 'service_type', From 631b58336d36bfd93e190582be34069613d38559 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 20 Nov 2025 22:18:54 -0800 Subject: [PATCH 13/63] fix: Content is marked non empty if its first part contains text or inline_data or file_data or func call/response PiperOrigin-RevId: 835063599 --- src/google/adk/flows/llm_flows/contents.py | 12 +++- .../flows/llm_flows/test_contents.py | 56 +++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index 9274cd462d..fefa014c45 100644 --- a/src/google/adk/flows/llm_flows/contents.py 
+++ b/src/google/adk/flows/llm_flows/contents.py @@ -224,7 +224,8 @@ def _contains_empty_content(event: Event) -> bool: This can happen to the events that only changed session state. When both content and transcriptions are empty, the event will be considered - as empty. + as empty. The content is considered empty if none of its parts contain text, + inline data, file data, function call, or function response. Args: event: The event to check. @@ -239,7 +240,14 @@ def _contains_empty_content(event: Event) -> bool: not event.content or not event.content.role or not event.content.parts - or event.content.parts[0].text == '' + or all( + not p.text + and not p.inline_data + and not p.file_data + and not p.function_call + and not p.function_response + for p in [event.content.parts[0]] + ) ) and (not event.output_transcription and not event.input_transcription) diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py index cf55630b67..b2aa91dbee 100644 --- a/tests/unittests/flows/llm_flows/test_contents.py +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -465,6 +465,43 @@ async def test_events_with_empty_content_are_skipped(): author="user", content=types.UserContent("How are you?"), ), + # Event with content that has only empty text part + Event( + invocation_id="inv6", + author="user", + content=types.Content(parts=[types.Part(text="")], role="model"), + ), + # Event with content that has only inline data part + Event( + invocation_id="inv7", + author="user", + content=types.Content( + parts=[ + types.Part( + inline_data=types.Blob( + data=b"test", mime_type="image/png" + ) + ) + ], + role="user", + ), + ), + # Event with content that has only file data part + Event( + invocation_id="inv8", + author="user", + content=types.Content( + parts=[ + types.Part( + file_data=types.FileData( + file_uri="gs://test_bucket/test_file.png", + mime_type="image/png", + ) + ) + ], + role="user", + ), + ), ] invocation_context.session.events = events @@ -478,4 +515,23 @@ async def test_events_with_empty_content_are_skipped(): assert llm_request.contents == [ types.UserContent("Hello"), types.UserContent("How are you?"), + types.Content( + parts=[ + types.Part( + inline_data=types.Blob(data=b"test", mime_type="image/png") + ) + ], + role="user", + ), + types.Content( + parts=[ + types.Part( + file_data=types.FileData( + file_uri="gs://test_bucket/test_file.png", + mime_type="image/png", + ) + ) + ], + role="user", + ), ] From 59eba96ea4cff2ed303572327334a0820a6c25d1 Mon Sep 17 00:00:00 2001 From: Austin Wise Date: Fri, 21 Nov 2025 09:09:37 -0800 Subject: [PATCH 14/63] fix: Remove unused, incorrect import PiperOrigin-RevId: 835247166 --- src/google/adk/cli/service_registry.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/google/adk/cli/service_registry.py b/src/google/adk/cli/service_registry.py index 9f23b73035..521a3f9f7d 100644 --- a/src/google/adk/cli/service_registry.py +++ b/src/google/adk/cli/service_registry.py @@ -76,7 +76,6 @@ def my_session_factory(uri: str, **kwargs): from ..artifacts.base_artifact_service import BaseArtifactService from ..memory.base_memory_service import BaseMemoryService -from ..sessions import InMemorySessionService from ..sessions.base_session_service import BaseSessionService from ..utils import yaml_utils From 11df1e886d74c6e207aefafb175d8bd8246a3d3b Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 21 Nov 2025 10:11:49 -0800 Subject: [PATCH 15/63] fix: Change `pass` to `yield` in 
BaseLlmConnection.receive PiperOrigin-RevId: 835268090 --- src/google/adk/models/base_llm_connection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/google/adk/models/base_llm_connection.py b/src/google/adk/models/base_llm_connection.py index afce550b13..1bf522740e 100644 --- a/src/google/adk/models/base_llm_connection.py +++ b/src/google/adk/models/base_llm_connection.py @@ -72,7 +72,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: Yields: LlmResponse: The model response. """ - pass + # We need to yield here to help type checkers infer the correct type. + yield @abstractmethod async def close(self): From a4453c884cd6f4bd6f5ddcd25839657e02ab4f9d Mon Sep 17 00:00:00 2001 From: Max Ind Date: Fri, 21 Nov 2025 10:17:08 -0800 Subject: [PATCH 16/63] fix: adk deploy agent_engine uses correct URI during an update Co-authored-by: Max Ind PiperOrigin-RevId: 835269976 --- src/google/adk/cli/cli_deploy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/cli/cli_deploy.py b/src/google/adk/cli/cli_deploy.py index 8da65d0e21..45dce7fda6 100644 --- a/src/google/adk/cli/cli_deploy.py +++ b/src/google/adk/cli/cli_deploy.py @@ -915,7 +915,7 @@ def to_agent_engine( ) else: if project and region and not agent_engine_id.startswith('projects/'): - agent_engine_id = f'projects/{project}/locations/{region}/agentEngines/{agent_engine_id}' + agent_engine_id = f'projects/{project}/locations/{region}/reasoningEngines/{agent_engine_id}' client.agent_engines.update(name=agent_engine_id, config=agent_config) click.secho(f'✅ Updated agent engine: {agent_engine_id}', fg='green') finally: From 89aee16f166d5512d10f04600c9865d5858e0be6 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 21 Nov 2025 11:18:06 -0800 Subject: [PATCH 17/63] chore: Allow google-cloud-storage >=2.18.0 Resolves https://github.com/google/adk-python/issues/3641 Co-authored-by: Shangjie Chen PiperOrigin-RevId: 835292931 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4f8e42bcf7..7151d661db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ dependencies = [ "google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool "google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database "google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription - "google-cloud-storage>=3.0.0, <4.0.0", # For GCS Artifact service + "google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service "google-genai>=1.45.0, <2.0.0", # Google GenAI SDK "graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering "jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation From 5583bb819b76738f259aa792dd57c5585da9d252 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 21 Nov 2025 11:43:04 -0800 Subject: [PATCH 18/63] chore: Update MCP requirement to >1.10.0 Resolves https://github.com/google/adk-python/issues/3644 Co-authored-by: Shangjie Chen PiperOrigin-RevId: 835302097 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7151d661db..5c0515d6b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "google-genai>=1.45.0, <2.0.0", # Google GenAI SDK "graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering "jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation - "mcp>=1.8.0, <2.0.0", # For MCP Toolset + "mcp>=1.10.0, <2.0.0", # For MCP Toolset "opentelemetry-api>=1.37.0, 
<=1.37.0", # OpenTelemetry - limit upper version for sdk and api to not risk breaking changes from unstable _logs package. "opentelemetry-exporter-gcp-logging>=1.9.0a0, <2.0.0", "opentelemetry-exporter-gcp-monitoring>=1.9.0a0, <2.0.0", From 23f1d8914afc0c913877bae42270d903453f571e Mon Sep 17 00:00:00 2001 From: Rohit Yanamadala Date: Fri, 21 Nov 2025 12:53:49 -0800 Subject: [PATCH 19/63] docs(agent): Implement stale issue bot Merge https://github.com/google/adk-python/pull/3546 Co-authored-by: Xuan Yang COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3546 from ryanaiagent:feat/stale-issue-agent bcf45098c1c6406b4a42228e4a8ef02f12840425 PiperOrigin-RevId: 835327931 --- .github/workflows/stale-bot.yml | 57 +++ .../adk_stale_agent/PROMPT_INSTRUCTION.txt | 40 ++ .../samples/adk_stale_agent/README.md | 65 +++ .../samples/adk_stale_agent/__init__.py | 15 + contributing/samples/adk_stale_agent/agent.py | 434 ++++++++++++++++++ contributing/samples/adk_stale_agent/main.py | 74 +++ .../samples/adk_stale_agent/settings.py | 49 ++ contributing/samples/adk_stale_agent/utils.py | 59 +++ 8 files changed, 793 insertions(+) create mode 100644 .github/workflows/stale-bot.yml create mode 100644 contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt create mode 100644 contributing/samples/adk_stale_agent/README.md create mode 100644 contributing/samples/adk_stale_agent/__init__.py create mode 100644 contributing/samples/adk_stale_agent/agent.py create mode 100644 contributing/samples/adk_stale_agent/main.py create mode 100644 contributing/samples/adk_stale_agent/settings.py create mode 100644 contributing/samples/adk_stale_agent/utils.py diff --git a/.github/workflows/stale-bot.yml b/.github/workflows/stale-bot.yml new file mode 100644 index 0000000000..882cb7b432 --- /dev/null +++ b/.github/workflows/stale-bot.yml @@ -0,0 +1,57 @@ +# .github/workflows/stale-issue-auditor.yml + +# Best Practice: Always have a 'name' field at the top. +name: ADK Stale Issue Auditor + +# The 'on' block defines the triggers. +on: + # The 'workflow_dispatch' trigger allows manual runs. + workflow_dispatch: + + # The 'schedule' trigger runs the bot on a timer. + schedule: + # This runs at 6:00 AM UTC (e.g., 10 PM PST). + - cron: '0 6 * * *' + +# The 'jobs' block contains the work to be done. +jobs: + # A unique ID for the job. + audit-stale-issues: + # The runner environment. + runs-on: ubuntu-latest + + # Permissions for the job's temporary GITHUB_TOKEN. + # These are standard and syntactically correct. + permissions: + issues: write + contents: read + + # The sequence of steps for the job. + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install dependencies + # The '|' character allows for multi-line shell commands. + run: | + python -m pip install --upgrade pip + pip install requests google-adk + + - name: Run Auditor Agent Script + # The 'env' block for setting environment variables. + env: + GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }} + GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} + OWNER: google + REPO: adk-python + ISSUES_PER_RUN: 100 + LLM_MODEL_NAME: "gemini-2.5-flash" + PYTHONPATH: contributing/samples + + # The final 'run' command. 
+ run: python -m adk_stale_agent.main \ No newline at end of file diff --git a/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt b/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt new file mode 100644 index 0000000000..bb31889b23 --- /dev/null +++ b/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt @@ -0,0 +1,40 @@ +You are a highly intelligent and transparent repository auditor for '{OWNER}/{REPO}'. +Your job is to analyze all open issues and report on your findings before taking any action. + +**Primary Directive:** Ignore any events from users ending in `[bot]`. +**Reporting Directive:** For EVERY issue you analyze, you MUST output a concise, human-readable summary, starting with "Analysis for Issue #[number]:". + +**WORKFLOW:** +1. **Context Gathering**: Call `get_repository_maintainers` and `get_all_open_issues`. +2. **Per-Issue Analysis**: For each issue, call `get_issue_state`, passing in the maintainers list. +3. **Decision & Reporting**: Based on the summary from `get_issue_state`, follow this strict decision tree in order. + +--- **DECISION TREE & REPORTING TEMPLATES** --- + +**STEP 1: CHECK FOR ACTIVITY (IS THE ISSUE ACTIVE?)** +- **Condition**: Was the last human action NOT from a maintainer? (i.e., `last_human_commenter_is_maintainer` is `False`). +- **Action**: The author or a third party has acted. The issue is ACTIVE. + - **Report and Action**: If '{STALE_LABEL_NAME}' is present, report: "Analysis for Issue #[number]: Issue is ACTIVE. The last action was a [action type] by a non-maintainer. To get the [action type], you MUST use the value from the 'last_human_action_type' field in the summary you received from the tool." Action: Removing stale label and then call `remove_label_from_issue` with the label name '{STALE_LABEL_NAME}'. Otherwise, report: "Analysis for Issue #[number]: Issue is ACTIVE. No stale label to remove. Action: None." +- **If this condition is met, stop processing this issue.** + +**STEP 2: IF PENDING, MANAGE THE STALE LIFECYCLE.** +- **Condition**: The last human action WAS from a maintainer (`last_human_commenter_is_maintainer` is `True`). The issue is PENDING. +- **Action**: You must now determine the correct state. + + - **First, check if the issue is already STALE.** + - **Condition**: Is the `'{STALE_LABEL_NAME}'` label present in `current_labels`? + - **Action**: The issue is STALE. Your only job is to check if it should be closed. + - **Get Time Difference**: Call `calculate_time_difference` with the `stale_label_applied_at` timestamp. + - **Decision & Report**: If `hours_passed` > **{CLOSE_HOURS_AFTER_STALE_THRESHOLD}**: Report "Analysis for Issue #[number]: STALE. Close threshold met ({CLOSE_HOURS_AFTER_STALE_THRESHOLD} hours) with no author activity." Action: Closing issue and then call `close_as_stale`. Otherwise, report "Analysis for Issue #[number]: STALE. Close threshold not yet met. Action: None." + + - **ELSE (the issue is PENDING but not yet stale):** + - **Analyze Intent**: Semantically analyze the `last_maintainer_comment_text`. Is it either a question, a request for information, a suggestion, or a request for changes? + - **If YES (it is either a question, a request for information, a suggestion, or a request for changes)**: + - **CRITICAL CHECK**: Now, you must verify the author has not already responded. Compare the `last_author_event_time` and the `last_maintainer_comment_time`. 
+ - **IF the author has NOT responded** (i.e., `last_author_event_time` is older than `last_maintainer_comment_time` or is null): + - **Get Time Difference**: Call `calculate_time_difference` with the `last_maintainer_comment_time`. + - **Decision & Report**: If `hours_passed` > **{STALE_HOURS_THRESHOLD}**: Report "Analysis for Issue #[number]: PENDING. Stale threshold met ({STALE_HOURS_THRESHOLD} hours)." Action: Marking as stale and then call `add_stale_label_and_comment` and if label name '{REQUEST_CLARIFICATION_LABEL}' is missing then call `add_label_to_issue` with the label name '{REQUEST_CLARIFICATION_LABEL}'. Otherwise, report: "Analysis for Issue #[number]: PENDING. Stale threshold not met. Action: None." + - **ELSE (the author HAS responded)**: + - **Report**: "Analysis for Issue #[number]: PENDING, but author has already responded to the last maintainer request. Action: None." + - **If NO (it is not a request):** + - **Report**: "Analysis for Issue #[number]: PENDING. Maintainer's last comment was not a request. Action: None." \ No newline at end of file diff --git a/contributing/samples/adk_stale_agent/README.md b/contributing/samples/adk_stale_agent/README.md new file mode 100644 index 0000000000..17b427d77c --- /dev/null +++ b/contributing/samples/adk_stale_agent/README.md @@ -0,0 +1,65 @@ +# ADK Stale Issue Auditor Agent + +This directory contains an autonomous agent designed to audit a GitHub repository for stale issues, helping to maintain repository hygiene and ensure that all open items are actionable. + +The agent operates as a "Repository Auditor," proactively scanning all open issues rather than waiting for a specific trigger. It uses a combination of deterministic Python tools and the semantic understanding of a Large Language Model (LLM) to make intelligent decisions about the state of a conversation. + +--- + +## Core Logic & Features + +The agent's primary goal is to identify issues where a maintainer has requested information from the author, and to manage the lifecycle of that issue based on the author's response (or lack thereof). + +**The agent follows a precise decision tree:** + +1. **Audits All Open Issues:** On each run, the agent fetches a batch of the oldest open issues in the repository. +2. **Identifies Pending Issues:** It analyzes the full timeline of each issue to see if the last human action was a comment from a repository maintainer. +3. **Semantic Intent Analysis:** If the last comment was from a maintainer, the agent uses the LLM to determine if the comment was a **question or a request for clarification**. +4. **Marks as Stale:** If the maintainer's question has gone unanswered by the author for a configurable period (e.g., 7 days), the agent will: + * Apply a `stale` label to the issue. + * Post a comment notifying the author that the issue is now considered stale and will be closed if no further action is taken. + * Proactively add a `request clarification` label if it's missing, to make the issue's state clear. +5. **Handles Activity:** If any non-maintainer (the author or a third party) comments on an issue, the agent will automatically remove the `stale` label, marking the issue as active again. +6. **Closes Stale Issues:** If an issue remains in the `stale` state for another configurable period (e.g., 7 days) with no new activity, the agent will post a final comment and close the issue. + +### Self-Configuration + +A key feature of this agent is its ability to self-configure. It does not require a hard-coded list of maintainer usernames. 
On each run, it uses the GitHub API to dynamically fetch the list of users with write access to the repository, ensuring its logic is always based on the current team. + +--- + +## Configuration + +The agent is configured entirely via environment variables, which should be set as secrets in the GitHub Actions workflow environment. + +### Required Secrets + +| Secret Name | Description | +| :--- | :--- | +| `GITHUB_TOKEN` | A GitHub Personal Access Token (PAT) with the required permissions. It's recommended to use a PAT from a dedicated "bot" account. +| `GOOGLE_API_KEY` | An API key for the Google AI (Gemini) model used for the agent's reasoning. + +### Required PAT Permissions + +The `GITHUB_TOKEN` requires the following **Repository Permissions**: +* **Issues**: `Read & write` (to read issues, add labels, comment, and close) +* **Administration**: `Read-only` (to read the list of repository collaborators/maintainers) + +### Optional Configuration + +These environment variables can be set in the workflow file to override the defaults in `settings.py`. + +| Variable Name | Description | Default | +| :--- | :--- | :--- | +| `STALE_HOURS_THRESHOLD` | The number of hours of inactivity after a maintainer's question before an issue is marked as `stale`. | `168` (7 days) | +| `CLOSE_HOURS_AFTER_STALE_THRESHOLD` | The number of hours after being marked `stale` before an issue is closed. | `168` (7 days) | +| `ISSUES_PER_RUN` | The maximum number of oldest open issues to process in a single workflow run. | `100` | +| `LLM_MODEL_NAME`| LLM model to use. | `gemini-2.5-flash` | + +--- + +## Deployment + +To deploy this agent, a GitHub Actions workflow file (`.github/workflows/stale-bot.yml`) is included. This workflow runs on a daily schedule and executes the agent's main script. + +Ensure the necessary repository secrets are configured and the `stale` and `request clarification` labels exist in the repository. \ No newline at end of file diff --git a/contributing/samples/adk_stale_agent/__init__.py b/contributing/samples/adk_stale_agent/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/adk_stale_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/adk_stale_agent/agent.py b/contributing/samples/adk_stale_agent/agent.py new file mode 100644 index 0000000000..abcb128288 --- /dev/null +++ b/contributing/samples/adk_stale_agent/agent.py @@ -0,0 +1,434 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +from datetime import timezone +import logging +import os +from typing import Any + +from adk_stale_agent.settings import CLOSE_HOURS_AFTER_STALE_THRESHOLD +from adk_stale_agent.settings import GITHUB_BASE_URL +from adk_stale_agent.settings import ISSUES_PER_RUN +from adk_stale_agent.settings import LLM_MODEL_NAME +from adk_stale_agent.settings import OWNER +from adk_stale_agent.settings import REPO +from adk_stale_agent.settings import REQUEST_CLARIFICATION_LABEL +from adk_stale_agent.settings import STALE_HOURS_THRESHOLD +from adk_stale_agent.settings import STALE_LABEL_NAME +from adk_stale_agent.utils import delete_request +from adk_stale_agent.utils import error_response +from adk_stale_agent.utils import get_request +from adk_stale_agent.utils import patch_request +from adk_stale_agent.utils import post_request +import dateutil.parser +from google.adk.agents.llm_agent import Agent +from requests.exceptions import RequestException + +logger = logging.getLogger("google_adk." + __name__) + +# --- Primary Tools for the Agent --- + + +def load_prompt_template(filename: str) -> str: + """Loads the prompt text file from the same directory as this script. + + Args: + filename: The name of the prompt file to load. + + Returns: + The content of the file as a string. + """ + file_path = os.path.join(os.path.dirname(__file__), filename) + + with open(file_path, "r") as f: + return f.read() + + +PROMPT_TEMPLATE = load_prompt_template("PROMPT_INSTRUCTION.txt") + + +def get_repository_maintainers() -> dict[str, Any]: + """ + Fetches the list of repository collaborators with 'push' (write) access or higher. + This should only be called once per run. + + Returns: + A dictionary with the status and a list of maintainer usernames, or an + error dictionary. + """ + logger.debug("Fetching repository maintainers with push access...") + try: + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/collaborators" + params = {"permission": "push"} + collaborators_data = get_request(url, params) + + maintainers = [user["login"] for user in collaborators_data] + logger.info(f"Found {len(maintainers)} repository maintainers.") + logger.debug(f"Maintainer list: {maintainers}") + + return {"status": "success", "maintainers": maintainers} + except RequestException as e: + logger.error(f"Failed to fetch repository maintainers: {e}", exc_info=True) + return error_response(f"Error fetching repository maintainers: {e}") + + +def get_all_open_issues() -> dict[str, Any]: + """Fetches a batch of the oldest open issues for an audit. + + Returns: + A dictionary containing the status and a list of open issues, or an error + dictionary. + """ + logger.info( + f"Fetching a batch of {ISSUES_PER_RUN} oldest open issues for audit..." + ) + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues" + params = { + "state": "open", + "sort": "created", + "direction": "asc", + "per_page": ISSUES_PER_RUN, + } + try: + items = get_request(url, params) + logger.info(f"Found {len(items)} open issues to audit.") + return {"status": "success", "items": items} + except RequestException as e: + logger.error(f"Failed to fetch open issues: {e}", exc_info=True) + return error_response(f"Error fetching all open issues: {e}") + + +def get_issue_state(item_number: int, maintainers: list[str]) -> dict[str, Any]: + """Analyzes an issue's complete history to create a comprehensive state summary. 
+ + This function acts as the primary "detective" for the agent. It performs the + complex, deterministic work of fetching and parsing an issue's full history, + allowing the LLM agent to focus on high-level semantic decision-making. + + It is designed to be highly robust by fetching the complete, multi-page history + from the GitHub `/timeline` API. By handling pagination correctly, it ensures + that even issues with a very long history (more than 100 events) are analyzed + in their entirety, preventing incorrect decisions based on incomplete data. + + Args: + item_number (int): The number of the GitHub issue or pull request to analyze. + maintainers (list[str]): A dynamically fetched list of GitHub usernames to be + considered maintainers. This is used to categorize actors found in + the issue's history. + + Returns: + A dictionary that serves as a clean, factual report summarizing the + issue's state. On failure, it returns a dictionary with an 'error' status. + """ + try: + # Fetch core issue data and prepare for timeline fetching. + logger.debug(f"Fetching full timeline for issue #{item_number}...") + issue_url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}" + issue_data = get_request(issue_url) + + # Fetch All pages from the timeline API to build a complete history. + timeline_url_base = f"{issue_url}/timeline" + timeline_data = [] + page = 1 + + while True: + paginated_url = f"{timeline_url_base}?per_page=100&page={page}" + logger.debug(f"Fetching timeline page {page} for issue #{item_number}...") + events_page = get_request(paginated_url) + if not events_page: + break + timeline_data.extend(events_page) + if len(events_page) < 100: + break + page += 1 + + logger.debug( + f"Fetched a total of {len(timeline_data)} timeline events across" + f" {page-1} page(s) for issue #{item_number}." + ) + + # Initialize key variables for the analysis. + issue_author = issue_data.get("user", {}).get("login") + current_labels = [label["name"] for label in issue_data.get("labels", [])] + + # Filter and sort all events into a clean, chronological history of human activity. + human_events = [] + for event in timeline_data: + actor = event.get("actor", {}).get("login") + timestamp_str = event.get("created_at") or event.get("submitted_at") + + if not actor or not timestamp_str or actor.endswith("[bot]"): + continue + + event["parsed_time"] = dateutil.parser.isoparse(timestamp_str) + human_events.append(event) + + human_events.sort(key=lambda e: e["parsed_time"]) + + # Find the most recent, relevant events by iterating backwards. + last_maintainer_comment = None + stale_label_event_time = None + + for event in reversed(human_events): + if ( + not last_maintainer_comment + and event.get("actor", {}).get("login") in maintainers + and event.get("event") == "commented" + ): + last_maintainer_comment = event + + if ( + not stale_label_event_time + and event.get("event") == "labeled" + and event.get("label", {}).get("name") == STALE_LABEL_NAME + ): + stale_label_event_time = event["parsed_time"] + + if last_maintainer_comment and stale_label_event_time: + break + + last_author_action = next( + ( + e + for e in reversed(human_events) + if e.get("actor", {}).get("login") == issue_author + ), + None, + ) + + # Build and return the final summary report for the LLM agent. 
+ last_human_event = human_events[-1] if human_events else None + last_human_actor = ( + last_human_event.get("actor", {}).get("login") + if last_human_event + else None + ) + + return { + "status": "success", + "issue_author": issue_author, + "current_labels": current_labels, + "last_maintainer_comment_text": ( + last_maintainer_comment.get("body") + if last_maintainer_comment + else None + ), + "last_maintainer_comment_time": ( + last_maintainer_comment["parsed_time"].isoformat() + if last_maintainer_comment + else None + ), + "last_author_event_time": ( + last_author_action["parsed_time"].isoformat() + if last_author_action + else None + ), + "last_author_action_type": ( + last_author_action.get("event") if last_author_action else "unknown" + ), + "last_human_action_type": ( + last_human_event.get("event") if last_human_event else "unknown" + ), + "last_human_commenter_is_maintainer": ( + last_human_actor in maintainers if last_human_actor else False + ), + "stale_label_applied_at": ( + stale_label_event_time.isoformat() + if stale_label_event_time + else None + ), + } + + except RequestException as e: + logger.error( + f"Failed to fetch comprehensive issue state for #{item_number}: {e}", + exc_info=True, + ) + return error_response( + f"Error getting comprehensive issue state for #{item_number}: {e}" + ) + + +def calculate_time_difference(timestamp_str: str) -> dict[str, Any]: + """Calculates the difference in hours between a UTC timestamp string and now. + + Args: + timestamp_str: An ISO 8601 formatted timestamp string. + + Returns: + A dictionary with the status and the time difference in hours, or an error + dictionary. + """ + try: + if not timestamp_str: + return error_response("Input timestamp is empty.") + event_time = dateutil.parser.isoparse(timestamp_str) + current_time_utc = datetime.now(timezone.utc) + time_difference = current_time_utc - event_time + hours_passed = time_difference.total_seconds() / 3600 + return {"status": "success", "hours_passed": hours_passed} + except (dateutil.parser.ParserError, TypeError) as e: + logger.error( + "Error calculating time difference for timestamp" + f" '{timestamp_str}': {e}", + exc_info=True, + ) + return error_response(f"Error calculating time difference: {e}") + + +def add_label_to_issue(item_number: int, label_name: str) -> dict[str, Any]: + """Adds a specific label to an issue. + + Args: + item_number: The issue number. + label_name: The name of the label to add. + + Returns: + A dictionary indicating the status of the operation. + """ + logger.debug(f"Adding label '{label_name}' to issue #{item_number}.") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels" + try: + post_request(url, [label_name]) + logger.info( + f"Successfully added label '{label_name}' to issue #{item_number}." + ) + return {"status": "success"} + except RequestException as e: + logger.error(f"Failed to add '{label_name}' to issue #{item_number}: {e}") + return error_response(f"Error adding label: {e}") + + +def remove_label_from_issue( + item_number: int, label_name: str +) -> dict[str, Any]: + """Removes a specific label from an issue or PR. + + Args: + item_number: The issue number. + label_name: The name of the label to remove. + + Returns: + A dictionary indicating the status of the operation. 
+ """ + logger.debug(f"Removing label '{label_name}' from issue #{item_number}.") + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels/{label_name}" + try: + delete_request(url) + logger.info( + f"Successfully removed label '{label_name}' from issue #{item_number}." + ) + return {"status": "success"} + except RequestException as e: + logger.error( + f"Failed to remove '{label_name}' from issue #{item_number}: {e}" + ) + return error_response(f"Error removing label: {e}") + + +def add_stale_label_and_comment(item_number: int) -> dict[str, Any]: + """Adds the 'stale' label to an issue and posts a comment explaining why. + + Args: + item_number: The issue number. + + Returns: + A dictionary indicating the status of the operation. + """ + logger.debug(f"Adding stale label and comment to issue #{item_number}.") + comment = ( + "This issue has been automatically marked as stale because it has not" + " had recent activity after a maintainer requested clarification. It" + " will be closed if no further activity occurs within" + f" {CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24:.0f} days." + ) + try: + post_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments", + {"body": comment}, + ) + post_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels", + [STALE_LABEL_NAME], + ) + logger.info(f"Successfully marked issue #{item_number} as stale.") + return {"status": "success"} + except RequestException as e: + logger.error( + f"Failed to mark issue #{item_number} as stale: {e}", exc_info=True + ) + return error_response(f"Error marking issue as stale: {e}") + + +def close_as_stale(item_number: int) -> dict[str, Any]: + """Posts a final comment and closes an issue or PR as stale. + + Args: + item_number: The issue number. + + Returns: + A dictionary indicating the status of the operation. + """ + logger.debug(f"Closing issue #{item_number} as stale.") + comment = ( + "This has been automatically closed because it has been marked as stale" + f" for over {CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24:.0f} days." + ) + try: + post_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments", + {"body": comment}, + ) + patch_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}", + {"state": "closed"}, + ) + logger.info(f"Successfully closed issue #{item_number} as stale.") + return {"status": "success"} + except RequestException as e: + logger.error( + f"Failed to close issue #{item_number} as stale: {e}", exc_info=True + ) + return error_response(f"Error closing issue: {e}") + + +# --- Agent Definition --- + +root_agent = Agent( + model=LLM_MODEL_NAME, + name="adk_repository_auditor_agent", + description=( + "Audits open issues to manage their state based on conversation" + " history." 
+ ), + instruction=PROMPT_TEMPLATE.format( + OWNER=OWNER, + REPO=REPO, + STALE_LABEL_NAME=STALE_LABEL_NAME, + REQUEST_CLARIFICATION_LABEL=REQUEST_CLARIFICATION_LABEL, + STALE_HOURS_THRESHOLD=STALE_HOURS_THRESHOLD, + CLOSE_HOURS_AFTER_STALE_THRESHOLD=CLOSE_HOURS_AFTER_STALE_THRESHOLD, + ), + tools=[ + add_label_to_issue, + add_stale_label_and_comment, + calculate_time_difference, + close_as_stale, + get_all_open_issues, + get_issue_state, + get_repository_maintainers, + remove_label_from_issue, + ], +) diff --git a/contributing/samples/adk_stale_agent/main.py b/contributing/samples/adk_stale_agent/main.py new file mode 100644 index 0000000000..f6fba3fba0 --- /dev/null +++ b/contributing/samples/adk_stale_agent/main.py @@ -0,0 +1,74 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import time + +from adk_stale_agent.agent import root_agent +from adk_stale_agent.settings import OWNER +from adk_stale_agent.settings import REPO +from google.adk.cli.utils import logs +from google.adk.runners import InMemoryRunner +from google.genai import types + +logs.setup_adk_logger(level=logging.INFO) +logger = logging.getLogger("google_adk." + __name__) + +APP_NAME = "adk_stale_agent_app" +USER_ID = "adk_stale_agent_user" + + +async def main(): + """Initializes and runs the stale issue agent.""" + logger.info("--- Starting Stale Agent Run ---") + runner = InMemoryRunner(agent=root_agent, app_name=APP_NAME) + session = await runner.session_service.create_session( + user_id=USER_ID, app_name=APP_NAME + ) + + prompt_text = ( + "Find and process all open issues to manage staleness according to your" + " rules." + ) + logger.info(f"Initial Agent Prompt: {prompt_text}\n") + prompt_message = types.Content( + role="user", parts=[types.Part(text=prompt_text)] + ) + + async for event in runner.run_async( + user_id=USER_ID, session_id=session.id, new_message=prompt_message + ): + if ( + event.content + and event.content.parts + and hasattr(event.content.parts[0], "text") + ): + # Print the agent's "thoughts" and actions for logging purposes + logger.debug(f"** {event.author} (ADK): {event.content.parts[0].text}") + + logger.info(f"--- Stale Agent Run Finished---") + + +if __name__ == "__main__": + start_time = time.time() + logger.info(f"Initializing stale agent for repository: {OWNER}/{REPO}") + logger.info("-" * 80) + + asyncio.run(main()) + + logger.info("-" * 80) + end_time = time.time() + duration = end_time - start_time + logger.info(f"Script finished in {duration:.2f} seconds.") diff --git a/contributing/samples/adk_stale_agent/settings.py b/contributing/samples/adk_stale_agent/settings.py new file mode 100644 index 0000000000..1b71e451f3 --- /dev/null +++ b/contributing/samples/adk_stale_agent/settings.py @@ -0,0 +1,49 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from dotenv import load_dotenv + +# Load environment variables from a .env file for local testing +load_dotenv(override=True) + +# --- GitHub API Configuration --- +GITHUB_BASE_URL = "https://api.github.com" +GITHUB_TOKEN = os.getenv("GITHUB_TOKEN") +if not GITHUB_TOKEN: + raise ValueError("GITHUB_TOKEN environment variable not set") + +OWNER = os.getenv("OWNER", "google") +REPO = os.getenv("REPO", "adk-python") +LLM_MODEL_NAME = os.getenv("LLM_MODEL_NAME", "gemini-2.5-flash") + +STALE_LABEL_NAME = "stale" +REQUEST_CLARIFICATION_LABEL = "request clarification" + +# --- THRESHOLDS IN HOURS --- +# These values can be overridden in a .env file for rapid testing (e.g., STALE_HOURS_THRESHOLD=1) +# Default: 168 hours (7 days) +# The number of hours of inactivity after a maintainer comment before an issue is marked as stale. +STALE_HOURS_THRESHOLD = float(os.getenv("STALE_HOURS_THRESHOLD", 168)) + +# Default: 168 hours (7 days) +# The number of hours of inactivity after an issue is marked 'stale' before it is closed. +CLOSE_HOURS_AFTER_STALE_THRESHOLD = float( + os.getenv("CLOSE_HOURS_AFTER_STALE_THRESHOLD", 168) +) + +# --- BATCH SIZE CONFIGURATION --- +# The maximum number of oldest open issues to process in a single run of the bot. +ISSUES_PER_RUN = int(os.getenv("ISSUES_PER_RUN", 100)) diff --git a/contributing/samples/adk_stale_agent/utils.py b/contributing/samples/adk_stale_agent/utils.py new file mode 100644 index 0000000000..0efb051f72 --- /dev/null +++ b/contributing/samples/adk_stale_agent/utils.py @@ -0,0 +1,59 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Any + +from adk_stale_agent.settings import GITHUB_TOKEN +import requests + +_session = requests.Session() +_session.headers.update({ + "Authorization": f"token {GITHUB_TOKEN}", + "Accept": "application/vnd.github.v3+json", +}) + + +def get_request(url: str, params: dict[str, Any] | None = None) -> Any: + """Sends a GET request to the GitHub API.""" + response = _session.get(url, params=params or {}, timeout=60) + response.raise_for_status() + return response.json() + + +def post_request(url: str, payload: Any) -> Any: + """Sends a POST request to the GitHub API.""" + response = _session.post(url, json=payload, timeout=60) + response.raise_for_status() + return response.json() + + +def patch_request(url: str, payload: Any) -> Any: + """Sends a PATCH request to the GitHub API.""" + response = _session.patch(url, json=payload, timeout=60) + response.raise_for_status() + return response.json() + + +def delete_request(url: str) -> Any: + """Sends a DELETE request to the GitHub API.""" + response = _session.delete(url, timeout=60) + response.raise_for_status() + if response.status_code == 204: + return {"status": "success"} + return response.json() + + +def error_response(error_message: str) -> dict[str, Any]: + """Creates a standardized error dictionary for the agent.""" + return {"status": "error", "message": error_message} From 9d331abb4eb1c7cf65f1de2a698ebc3d4d072345 Mon Sep 17 00:00:00 2001 From: Giorgio Boa <35845425+gioboa@users.noreply.github.com> Date: Fri, 21 Nov 2025 13:38:57 -0800 Subject: [PATCH 20/63] ci: bump action scripts versions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge https://github.com/google/adk-python/pull/3638 Thanks for this great project 👏 This PR updates the GitHub actions dependencies to the latest version. ### Checklist - [X] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [X] I have performed a self-review of my own code. 
Co-authored-by: Hangfei Lin COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3638 from gioboa:ci/actions-versions f7d6f3b5233e8cb135c8af88d5b6e0ead8382055 PiperOrigin-RevId: 835343177 --- .github/workflows/analyze-releases-for-adk-docs-updates.yml | 4 ++-- .github/workflows/check-file-contents.yml | 2 +- .github/workflows/copybara-pr-handler.yml | 2 +- .github/workflows/discussion_answering.yml | 4 ++-- .github/workflows/isort.yml | 4 ++-- .github/workflows/pr-triage.yml | 4 ++-- .github/workflows/pyink.yml | 4 ++-- .github/workflows/python-unit-tests.yml | 4 ++-- .github/workflows/triage.yml | 4 ++-- .github/workflows/upload-adk-docs-to-vertex-ai-search.yml | 4 ++-- 10 files changed, 18 insertions(+), 18 deletions(-) diff --git a/.github/workflows/analyze-releases-for-adk-docs-updates.yml b/.github/workflows/analyze-releases-for-adk-docs-updates.yml index c8a86fac66..21414ae534 100644 --- a/.github/workflows/analyze-releases-for-adk-docs-updates.yml +++ b/.github/workflows/analyze-releases-for-adk-docs-updates.yml @@ -16,10 +16,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/check-file-contents.yml b/.github/workflows/check-file-contents.yml index 861e2247a0..bb575e0f20 100644 --- a/.github/workflows/check-file-contents.yml +++ b/.github/workflows/check-file-contents.yml @@ -24,7 +24,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 2 diff --git a/.github/workflows/copybara-pr-handler.yml b/.github/workflows/copybara-pr-handler.yml index 670389c527..4ca3c48803 100644 --- a/.github/workflows/copybara-pr-handler.yml +++ b/.github/workflows/copybara-pr-handler.yml @@ -25,7 +25,7 @@ jobs: steps: - name: Check for Copybara commits and close PRs - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: github-token: ${{ secrets.ADK_TRIAGE_AGENT }} script: | diff --git a/.github/workflows/discussion_answering.yml b/.github/workflows/discussion_answering.yml index 71c06ba9f6..d9bfffc361 100644 --- a/.github/workflows/discussion_answering.yml +++ b/.github/workflows/discussion_answering.yml @@ -15,10 +15,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/isort.yml b/.github/workflows/isort.yml index e1a087742c..b8b24da5ce 100644 --- a/.github/workflows/isort.yml +++ b/.github/workflows/isort.yml @@ -26,12 +26,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' diff --git a/.github/workflows/pr-triage.yml b/.github/workflows/pr-triage.yml index a8a2082094..55b088b505 100644 --- a/.github/workflows/pr-triage.yml +++ b/.github/workflows/pr-triage.yml @@ -20,10 +20,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/pyink.yml b/.github/workflows/pyink.yml index ef9e72e453..0822757fa0 100644 --- 
a/.github/workflows/pyink.yml +++ b/.github/workflows/pyink.yml @@ -26,12 +26,12 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 2 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.x' diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 42b6174813..8f8f46e953 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -29,10 +29,10 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml index 0d310cee89..97ddc0efa7 100644 --- a/.github/workflows/triage.yml +++ b/.github/workflows/triage.yml @@ -18,10 +18,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' diff --git a/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml b/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml index 9b1f042917..bce7598c2f 100644 --- a/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml +++ b/.github/workflows/upload-adk-docs-to-vertex-ai-search.yml @@ -13,7 +13,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Clone adk-docs repository run: git clone https://github.com/google/adk-docs.git /tmp/adk-docs @@ -22,7 +22,7 @@ jobs: run: git clone https://github.com/google/adk-python.git /tmp/adk-python - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' From 609c6172d92296e518de8b24dcbb3331304077cb Mon Sep 17 00:00:00 2001 From: Adrian Altermatt Date: Fri, 21 Nov 2025 13:48:30 -0800 Subject: [PATCH 21/63] docs: too many E(inv=2, role=user) plus reformatting Merge https://github.com/google/adk-python/pull/3538 Main change from: E(inv=2, role=user), E(inv=2, role=model), E(inv=2, role=user), To: E(inv=2, role=user), E(inv=2, role=model) I think the last E(inv=2, role=user) was wrong. Also reformatted. Co-authored-by: Hangfei Lin COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3538 from adrianad:patch-1 627b933bdc3e00e45f704edf95448281e32d127c PiperOrigin-RevId: 835346467 --- src/google/adk/apps/compaction.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/google/adk/apps/compaction.py b/src/google/adk/apps/compaction.py index a6f55f9ad6..4511b1b96e 100644 --- a/src/google/adk/apps/compaction.py +++ b/src/google/adk/apps/compaction.py @@ -72,9 +72,10 @@ async def _run_compaction_for_sliding_window( beginning. - A `CompactedEvent` is generated, summarizing events within `invocation_id` range [1, 2]. - - The session now contains: `[E(inv=1, role=user), E(inv=1, role=model), - E(inv=2, role=user), E(inv=2, role=model), E(inv=2, role=user), - CompactedEvent(inv=[1, 2])]`. + - The session now contains: `[ + E(inv=1, role=user), E(inv=1, role=model), + E(inv=2, role=user), E(inv=2, role=model), + CompactedEvent(inv=[1, 2])]`. 2. 
**After `invocation_id` 3 events are added:** - No compaction happens yet, because only 1 new invocation (`inv=3`) @@ -91,10 +92,13 @@ async def _run_compaction_for_sliding_window( - The new compaction range is from `invocation_id` 2 to 4. - A new `CompactedEvent` is generated, summarizing events within `invocation_id` range [2, 4]. - - The session now contains: `[E(inv=1, role=user), E(inv=1, role=model), - E(inv=2, role=user), E(inv=2, role=model), E(inv=2, role=user), - CompactedEvent(inv=[1, 2]), E(inv=3, role=user), E(inv=3, role=model), - E(inv=4, role=user), E(inv=4, role=model), CompactedEvent(inv=[2, 4])]`. + - The session now contains: `[ + E(inv=1, role=user), E(inv=1, role=model), + E(inv=2, role=user), E(inv=2, role=model), + CompactedEvent(inv=[1, 2]), + E(inv=3, role=user), E(inv=3, role=model), + E(inv=4, role=user), E(inv=4, role=model), + CompactedEvent(inv=[2, 4])]`. Args: From 2247a45922afdf0a733239b619f45601d9b325ec Mon Sep 17 00:00:00 2001 From: saroj rout Date: Fri, 21 Nov 2025 14:23:33 -0800 Subject: [PATCH 22/63] feat(agents): add validation for unique sub-agent names (#3557) Merge https://github.com/google/adk-python/pull/3576 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: #3557 - Related: #_issue_number_ **2. Or, if no issue exists, describe the change:** _If applicable, please follow the issue templates to provide as much detail as possible._ **Problem:** When creating a BaseAgent with multiple sub-agents, there was no validation to ensure that all sub-agents have unique names. This could lead to confusion when trying to find or reference specific sub-agents by name, as duplicate names would make it ambiguous which agent is being referenced. **Solution:** Added a @field_validator for the sub_agents field in BaseAgent that validates all sub-agents have unique names. The validator: Checks for duplicate names in the sub-agents list Raises a ValueError with a clear error message listing all duplicate names found Returns the validated list if all names are unique Handles edge cases like empty lists gracefully ### Testing Plan _Please describe the tests that you ran to verify your changes. This is required for all PRs that are not small documentation or typo fixes._ **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. _Please include a summary of passed `pytest` results._ Added 4 new test cases in tests/unittests/agents/test_base_agent.py: test_validate_sub_agents_unique_names_single_duplicate: Verifies that a single duplicate name raises ValueError test_validate_sub_agents_unique_names_multiple_duplicates: Verifies that multiple duplicate names are all reported in the error message test_validate_sub_agents_unique_names_no_duplicates: Verifies that unique names pass validation successfully test_validate_sub_agents_unique_names_empty_list: Verifies that empty sub-agents list passes validation All tests pass locally. 
You can run with: pytest tests/unittests/agents/test_base_agent.py::test_validate_sub_agents_unique_names_single_duplicate tests/unittests/agents/test_base_agent.py::test_validate_sub_agents_unique_names_multiple_duplicates tests/unittests/agents/test_base_agent.py::test_validate_sub_agents_unique_names_no_duplicates tests/unittests/agents/test_base_agent.py::test_validate_sub_agents_unique_names_empty_list -v **Manual End-to-End (E2E) Tests:** _Please provide instructions on how to manually test your changes, including any necessary setup or configuration. Please provide logs or screenshots to help reviewers better understand the fix._ Test Case 1: Duplicate names should raise error from google.adk.agents import Agent agent1 = Agent(name="sub_agent", model="gemini-2.5-flash") agent2 = Agent(name="sub_agent", model="gemini-2.5-flash") # Same name # This should raise ValueError try: parent = Agent( name="parent", model="gemini-2.5-flash", sub_agents=[agent1, agent2] ) except ValueError as e: print(f"Expected error: {e}") # Output: Found duplicate sub-agent names: `sub_agent`. All sub-agents must have unique names. Test Case 2: Unique names should work from google.adk.agents import Agent agent1 = Agent(name="agent1", model="gemini-2.5-flash") agent2 = Agent(name="agent2", model="gemini-2.5-flash") # This should work without error parent = Agent( name="parent", model="gemini-2.5-flash", sub_agents=[agent1, agent2] ) print("Success: Unique names validated correctly") ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. ### Additional context This change adds validation at the BaseAgent level, so it automatically applies to all agent types that inherit from BaseAgent (e.g., LlmAgent, LoopAgent, etc.). The validation uses Pydantic's field validator system, which runs during object initialization, ensuring the constraint is enforced early and consistently. The error message clearly identifies which names are duplicated, making it easy for developers to fix the issue: COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3576 from sarojrout:feat/validate-unique-sub-agent-names 07adf1f9a5fc935389eb9dfa3cbc1311f551ebe3 PiperOrigin-RevId: 835358118 --- src/google/adk/agents/base_agent.py | 40 +++++++++ tests/unittests/agents/test_base_agent.py | 104 ++++++++++++++++++++++ 2 files changed, 144 insertions(+) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a644cb8b90..fccde1cb6f 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -563,6 +563,46 @@ def validate_name(cls, value: str): ) return value + @field_validator('sub_agents', mode='after') + @classmethod + def validate_sub_agents_unique_names( + cls, value: list[BaseAgent] + ) -> list[BaseAgent]: + """Validates that all sub-agents have unique names. + + Args: + value: The list of sub-agents to validate. + + Returns: + The validated list of sub-agents. + + Raises: + ValueError: If duplicate sub-agent names are found. 
+ """ + if not value: + return value + + seen_names: set[str] = set() + duplicates: set[str] = set() + + for sub_agent in value: + name = sub_agent.name + if name in seen_names: + duplicates.add(name) + else: + seen_names.add(name) + + if duplicates: + duplicate_names_str = ', '.join( + f'`{name}`' for name in sorted(duplicates) + ) + raise ValueError( + f'Found duplicate sub-agent names: {duplicate_names_str}. ' + 'All sub-agents must have unique names.' + ) + + return value + def __set_parent_agent_for_sub_agents(self) -> BaseAgent: for sub_agent in self.sub_agents: if sub_agent.parent_agent is not None: diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 663179f670..860cc8d4f0 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -854,6 +854,110 @@ def test_set_parent_agent_for_sub_agent_twice( ) +def test_validate_sub_agents_unique_names_single_duplicate( + request: pytest.FixtureRequest, +): + """Test that duplicate sub-agent names raise ValueError.""" + duplicate_name = f'{request.function.__name__}_duplicate_agent' + sub_agent_1 = _TestingAgent(name=duplicate_name) + sub_agent_2 = _TestingAgent(name=duplicate_name) + + with pytest.raises(ValueError, match='Found duplicate sub-agent names'): + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=[sub_agent_1, sub_agent_2], + ) + + +def test_validate_sub_agents_unique_names_multiple_duplicates( + request: pytest.FixtureRequest, +): + """Test that multiple duplicate sub-agent names are all reported.""" + duplicate_name_1 = f'{request.function.__name__}_duplicate_1' + duplicate_name_2 = f'{request.function.__name__}_duplicate_2' + + sub_agents = [ + _TestingAgent(name=duplicate_name_1), + _TestingAgent(name=f'{request.function.__name__}_unique'), + _TestingAgent(name=duplicate_name_1), # First duplicate + _TestingAgent(name=duplicate_name_2), + _TestingAgent(name=duplicate_name_2), # Second duplicate + ] + + with pytest.raises(ValueError) as exc_info: + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + error_message = str(exc_info.value) + # Verify each duplicate name appears exactly once in the error message + assert error_message.count(duplicate_name_1) == 1 + assert error_message.count(duplicate_name_2) == 1 + # Verify both duplicate names are present + assert duplicate_name_1 in error_message + assert duplicate_name_2 in error_message + + +def test_validate_sub_agents_unique_names_triple_duplicate( + request: pytest.FixtureRequest, +): + """Test that a name appearing three times is reported only once.""" + duplicate_name = f'{request.function.__name__}_triple_duplicate' + + sub_agents = [ + _TestingAgent(name=duplicate_name), + _TestingAgent(name=f'{request.function.__name__}_unique'), + _TestingAgent(name=duplicate_name), # Second occurrence + _TestingAgent(name=duplicate_name), # Third occurrence + ] + + with pytest.raises(ValueError) as exc_info: + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + error_message = str(exc_info.value) + # Verify the duplicate name appears exactly once in the error message + # (not three times even though it appears three times in the list) + assert error_message.count(duplicate_name) == 1 + assert duplicate_name in error_message + + +def test_validate_sub_agents_unique_names_no_duplicates( + request: pytest.FixtureRequest, +): + """Test that unique sub-agent names pass 
validation.""" + sub_agents = [ + _TestingAgent(name=f'{request.function.__name__}_sub_agent_1'), + _TestingAgent(name=f'{request.function.__name__}_sub_agent_2'), + _TestingAgent(name=f'{request.function.__name__}_sub_agent_3'), + ] + + parent = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + assert len(parent.sub_agents) == 3 + assert parent.sub_agents[0].name == f'{request.function.__name__}_sub_agent_1' + assert parent.sub_agents[1].name == f'{request.function.__name__}_sub_agent_2' + assert parent.sub_agents[2].name == f'{request.function.__name__}_sub_agent_3' + + +def test_validate_sub_agents_unique_names_empty_list( + request: pytest.FixtureRequest, +): + """Test that empty sub-agents list passes validation.""" + parent = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=[], + ) + + assert len(parent.sub_agents) == 0 + + if __name__ == '__main__': pytest.main([__file__]) From 777dba3033a9a14667fb009ba017f648177be41d Mon Sep 17 00:00:00 2001 From: davidkl97 Date: Fri, 21 Nov 2025 14:49:06 -0800 Subject: [PATCH 23/63] feat(tools): Add an option to disallow propagating runner plugins to AgentTool runner Merge https://github.com/google/adk-python/pull/2779 Fixes #2780 ### testing plan not available as is doesn't introduce new functionality Co-authored-by: Wei Sun (Jack) COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/2779 from davidkl97:feature/agent-tool-plugins a602c808789f3daeed6244e352a6fb8fb6972de3 PiperOrigin-RevId: 835366974 --- src/google/adk/tools/agent_tool.py | 27 +++++- tests/unittests/tools/test_agent_tool.py | 109 +++++++++++++++++++++++ 2 files changed, 133 insertions(+), 3 deletions(-) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 702f6e43aa..abfcf36979 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -45,11 +45,22 @@ class AgentTool(BaseTool): Attributes: agent: The agent to wrap. skip_summarization: Whether to skip summarization of the agent output. + include_plugins: Whether to propagate plugins from the parent runner context + to the agent's runner. When True (default), the agent will inherit all + plugins from its parent. Set to False to run the agent with an isolated + plugin environment. 
""" - def __init__(self, agent: BaseAgent, skip_summarization: bool = False): + def __init__( + self, + agent: BaseAgent, + skip_summarization: bool = False, + *, + include_plugins: bool = True, + ): self.agent = agent self.skip_summarization: bool = skip_summarization + self.include_plugins = include_plugins super().__init__(name=agent.name, description=agent.description) @@ -130,6 +141,11 @@ async def run_async( invocation_context.app_name if invocation_context else None ) child_app_name = parent_app_name or self.agent.name + plugins = ( + tool_context._invocation_context.plugin_manager.plugins + if self.include_plugins + else None + ) runner = Runner( app_name=child_app_name, agent=self.agent, @@ -137,7 +153,7 @@ async def run_async( session_service=InMemorySessionService(), memory_service=InMemoryMemoryService(), credential_service=tool_context._invocation_context.credential_service, - plugins=list(tool_context._invocation_context.plugin_manager.plugins), + plugins=plugins, ) state_dict = { @@ -192,7 +208,9 @@ def from_config( agent_tool_config.agent, config_abs_path ) return cls( - agent=agent, skip_summarization=agent_tool_config.skip_summarization + agent=agent, + skip_summarization=agent_tool_config.skip_summarization, + include_plugins=agent_tool_config.include_plugins, ) @@ -204,3 +222,6 @@ class AgentToolConfig(BaseToolConfig): skip_summarization: bool = False """Whether to skip summarization of the agent output.""" + + include_plugins: bool = True + """Whether to include plugins from parent runner context.""" diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index 85e8b9caa1..f2bc97b207 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -570,3 +570,112 @@ class CustomInput(BaseModel): # Should have string response schema for VERTEX_AI when no output_schema assert declaration.response is not None assert declaration.response.type == types.Type.STRING + + +def test_include_plugins_default_true(): + """Test that plugins are propagated by default (include_plugins=True).""" + + # Create a test plugin that tracks callbacks + class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent)], # Default include_plugins=True + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should be called for both root_agent and tool_agent + assert tracking_plugin.before_agent_calls == 2 + + +def test_include_plugins_explicit_true(): + """Test that plugins are propagated when include_plugins=True.""" + + class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = 
Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent, include_plugins=True)], + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should be called for both root_agent and tool_agent + assert tracking_plugin.before_agent_calls == 2 + + +def test_include_plugins_false(): + """Test that plugins are NOT propagated when include_plugins=False.""" + + class TrackingPlugin(BasePlugin): + + def __init__(self, name: str): + super().__init__(name) + self.before_agent_calls = 0 + + async def before_agent_callback(self, **kwargs): + self.before_agent_calls += 1 + + tracking_plugin = TrackingPlugin(name='tracking') + + mock_model = testing_utils.MockModel.create( + responses=[function_call_no_schema, 'response1', 'response2'] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model, + ) + + root_agent = Agent( + name='root_agent', + model=mock_model, + tools=[AgentTool(agent=tool_agent, include_plugins=False)], + ) + + runner = testing_utils.InMemoryRunner(root_agent, plugins=[tracking_plugin]) + runner.run('test1') + + # Plugin should only be called for root_agent, not tool_agent + assert tracking_plugin.before_agent_calls == 1 From 52674e7fac6b7689f0e3871d41c4523e13471a7e Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Fri, 21 Nov 2025 15:24:49 -0800 Subject: [PATCH 24/63] fix: Update AgentTool to use Agent's description when input_schema is provided in FunctionDeclaration Co-authored-by: Xuan Yang PiperOrigin-RevId: 835379243 --- src/google/adk/tools/agent_tool.py | 2 ++ tests/unittests/tools/test_agent_tool.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index abfcf36979..46d8616619 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -79,6 +79,8 @@ def _get_declaration(self) -> types.FunctionDeclaration: result = _automatic_function_calling_util.build_function_declaration( func=self.agent.input_schema, variant=self._api_variant ) + # Override the description with the agent's description + result.description = self.agent.description else: result = types.FunctionDeclaration( parameters=types.Schema( diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index f2bc97b207..a9723b4347 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -679,3 +679,26 @@ async def before_agent_callback(self, **kwargs): # Plugin should only be called for root_agent, not tool_agent assert tracking_plugin.before_agent_calls == 1 + + +def test_agent_tool_description_with_input_schema(): + """Test that agent description is propagated when using input_schema.""" + + class CustomInput(BaseModel): + """This is the Pydantic model docstring.""" + + custom_input: str + + agent_description = 'This is the agent description that should be used' + tool_agent = Agent( + name='tool_agent', + model=testing_utils.MockModel.create(responses=['test response']), + description=agent_description, + input_schema=CustomInput, + ) + + agent_tool = AgentTool(agent=tool_agent) + declaration = agent_tool._get_declaration() + + # The description should come from the agent, not the Pydantic model + assert declaration.description == agent_description From 2e1f730c3bc0eb454b76d7f36b7b9f1da7304cfe Mon Sep 17 00:00:00 2001 From: George Weale Date: Fri, 21 Nov 2025 15:55:34 -0800 Subject: [PATCH 25/63] fix: Update 
LiteLLM system instruction role from "developer" to "system" This change replaces the use of `ChatCompletionDeveloperMessage` with `ChatCompletionSystemMessage` and sets the role to "system" for providing system instructions to LiteLLM models Close #3657 Co-authored-by: George Weale PiperOrigin-RevId: 835388738 --- src/google/adk/models/lite_llm.py | 6 +++--- tests/unittests/models/test_litellm.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index c9712974ea..9e3698b190 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -39,8 +39,8 @@ from litellm import acompletion from litellm import ChatCompletionAssistantMessage from litellm import ChatCompletionAssistantToolCall -from litellm import ChatCompletionDeveloperMessage from litellm import ChatCompletionMessageToolCall +from litellm import ChatCompletionSystemMessage from litellm import ChatCompletionToolMessage from litellm import ChatCompletionUserMessage from litellm import completion @@ -983,8 +983,8 @@ def _get_completion_inputs( if llm_request.config.system_instruction: messages.insert( 0, - ChatCompletionDeveloperMessage( - role="developer", + ChatCompletionSystemMessage( + role="system", content=llm_request.config.system_instruction, ), ) diff --git a/tests/unittests/models/test_litellm.py b/tests/unittests/models/test_litellm.py index c486806d37..f65fc77a61 100644 --- a/tests/unittests/models/test_litellm.py +++ b/tests/unittests/models/test_litellm.py @@ -1195,7 +1195,7 @@ async def test_generate_content_async_with_system_instruction( _, kwargs = mock_acompletion.call_args assert kwargs["model"] == "test_model" - assert kwargs["messages"][0]["role"] == "developer" + assert kwargs["messages"][0]["role"] == "system" assert kwargs["messages"][0]["content"] == "Test system instruction" assert kwargs["messages"][1]["role"] == "user" assert kwargs["messages"][1]["content"] == "Test prompt" From a9a418ba87b27baf435c147d80e105e69fae316a Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Fri, 21 Nov 2025 20:55:57 -0800 Subject: [PATCH 26/63] fix: Remove distructive validation Co-authored-by: Shangjie Chen PiperOrigin-RevId: 835466120 --- src/google/adk/agents/base_agent.py | 40 --------- tests/unittests/agents/test_base_agent.py | 104 ---------------------- 2 files changed, 144 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index fccde1cb6f..a644cb8b90 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -563,46 +563,6 @@ def validate_name(cls, value: str): ) return value - @field_validator('sub_agents', mode='after') - @classmethod - def validate_sub_agents_unique_names( - cls, value: list[BaseAgent] - ) -> list[BaseAgent]: - """Validates that all sub-agents have unique names. - - Args: - value: The list of sub-agents to validate. - - Returns: - The validated list of sub-agents. - - Raises: - ValueError: If duplicate sub-agent names are found. - """ - if not value: - return value - - seen_names: set[str] = set() - duplicates: set[str] = set() - - for sub_agent in value: - name = sub_agent.name - if name in seen_names: - duplicates.add(name) - else: - seen_names.add(name) - - if duplicates: - duplicate_names_str = ', '.join( - f'`{name}`' for name in sorted(duplicates) - ) - raise ValueError( - f'Found duplicate sub-agent names: {duplicate_names_str}. ' - 'All sub-agents must have unique names.' 
- ) - - return value - def __set_parent_agent_for_sub_agents(self) -> BaseAgent: for sub_agent in self.sub_agents: if sub_agent.parent_agent is not None: diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 860cc8d4f0..663179f670 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -854,110 +854,6 @@ def test_set_parent_agent_for_sub_agent_twice( ) -def test_validate_sub_agents_unique_names_single_duplicate( - request: pytest.FixtureRequest, -): - """Test that duplicate sub-agent names raise ValueError.""" - duplicate_name = f'{request.function.__name__}_duplicate_agent' - sub_agent_1 = _TestingAgent(name=duplicate_name) - sub_agent_2 = _TestingAgent(name=duplicate_name) - - with pytest.raises(ValueError, match='Found duplicate sub-agent names'): - _ = _TestingAgent( - name=f'{request.function.__name__}_parent', - sub_agents=[sub_agent_1, sub_agent_2], - ) - - -def test_validate_sub_agents_unique_names_multiple_duplicates( - request: pytest.FixtureRequest, -): - """Test that multiple duplicate sub-agent names are all reported.""" - duplicate_name_1 = f'{request.function.__name__}_duplicate_1' - duplicate_name_2 = f'{request.function.__name__}_duplicate_2' - - sub_agents = [ - _TestingAgent(name=duplicate_name_1), - _TestingAgent(name=f'{request.function.__name__}_unique'), - _TestingAgent(name=duplicate_name_1), # First duplicate - _TestingAgent(name=duplicate_name_2), - _TestingAgent(name=duplicate_name_2), # Second duplicate - ] - - with pytest.raises(ValueError) as exc_info: - _ = _TestingAgent( - name=f'{request.function.__name__}_parent', - sub_agents=sub_agents, - ) - - error_message = str(exc_info.value) - # Verify each duplicate name appears exactly once in the error message - assert error_message.count(duplicate_name_1) == 1 - assert error_message.count(duplicate_name_2) == 1 - # Verify both duplicate names are present - assert duplicate_name_1 in error_message - assert duplicate_name_2 in error_message - - -def test_validate_sub_agents_unique_names_triple_duplicate( - request: pytest.FixtureRequest, -): - """Test that a name appearing three times is reported only once.""" - duplicate_name = f'{request.function.__name__}_triple_duplicate' - - sub_agents = [ - _TestingAgent(name=duplicate_name), - _TestingAgent(name=f'{request.function.__name__}_unique'), - _TestingAgent(name=duplicate_name), # Second occurrence - _TestingAgent(name=duplicate_name), # Third occurrence - ] - - with pytest.raises(ValueError) as exc_info: - _ = _TestingAgent( - name=f'{request.function.__name__}_parent', - sub_agents=sub_agents, - ) - - error_message = str(exc_info.value) - # Verify the duplicate name appears exactly once in the error message - # (not three times even though it appears three times in the list) - assert error_message.count(duplicate_name) == 1 - assert duplicate_name in error_message - - -def test_validate_sub_agents_unique_names_no_duplicates( - request: pytest.FixtureRequest, -): - """Test that unique sub-agent names pass validation.""" - sub_agents = [ - _TestingAgent(name=f'{request.function.__name__}_sub_agent_1'), - _TestingAgent(name=f'{request.function.__name__}_sub_agent_2'), - _TestingAgent(name=f'{request.function.__name__}_sub_agent_3'), - ] - - parent = _TestingAgent( - name=f'{request.function.__name__}_parent', - sub_agents=sub_agents, - ) - - assert len(parent.sub_agents) == 3 - assert parent.sub_agents[0].name == f'{request.function.__name__}_sub_agent_1' - assert 
parent.sub_agents[1].name == f'{request.function.__name__}_sub_agent_2' - assert parent.sub_agents[2].name == f'{request.function.__name__}_sub_agent_3' - - -def test_validate_sub_agents_unique_names_empty_list( - request: pytest.FixtureRequest, -): - """Test that empty sub-agents list passes validation.""" - parent = _TestingAgent( - name=f'{request.function.__name__}_parent', - sub_agents=[], - ) - - assert len(parent.sub_agents) == 0 - - if __name__ == '__main__': pytest.main([__file__]) From cf21ca358478919207049695ba6b31dc6e0b2673 Mon Sep 17 00:00:00 2001 From: AlexeyChernenkoPlato Date: Sat, 22 Nov 2025 09:33:36 -0800 Subject: [PATCH 27/63] fix: double function response processing issue Merge https://github.com/google/adk-python/pull/2588 ## Description Fixes an issue in `base_llm_flow.py` where, in Bidi-streaming (live) mode, the multi-agent structure causes duplicated responses after tool calling. ## Problem In Bidi-streaming (live) mode, when utilizing a multi-agent structure, the leaf-level sub-agent and its parent agent both process the same function call response, leading to duplicate replies. This duplication occurs because the parent agent's live connection remains open while initiating a new connection with the child agent. ## Root Cause The issue originated from the placement of agent transfer logic in the `_postprocess_live` method at lines 547-557. When a `transfer_to_agent` function call was made: 1. The function response was processed in `_postprocess_live` 2. A recursive call to `agent_to_run.run_live` was initiated 3. This prevented the closure of the parent agent's connection at line 175 of the `run_live` method, as that code path was never reached 4. Both the parent and child agents remained active, causing both to process subsequent function responses ## Solution This PR addresses the issue by ensuring the parent agent's live connection is closed before initiating a new one with the child agent. Changes made: **Connection Management**: Moved the agent transfer logic from `_postprocess_live` method to the `run_live` method, specifically: - Removed agent transfer handling from lines 547-557 in `_postprocess_live` - Added agent transfer handling after connection closure at lines 176-184 in `run_live` **Code Refactoring**: The agent transfer now occurs in the proper sequence: 1. Parent agent processes the `transfer_to_agent` function response 2. Parent agent's live connection is properly closed (line 175) 3. New connection with child agent is initiated (line 182) 4. Child agent handles subsequent function calls without duplication **Improved Flow Control**: This ensures that each agent processes function call responses without duplication, maintaining proper connection lifecycle management in multi-agent structures. ## Testing To verify this fix works correctly: 1. **Multi-Agent Structure Test**: Set up a multi-agent structure with a parent agent that transfers to a child agent via `transfer_to_agent` function call 2. **Bidi-Streaming Mode**: Enable Bidi-streaming (live) mode in the configuration 3. **Function Call Verification**: Trigger a function call that results in agent transfer 4. **Response Monitoring**: Verify that only one response is generated (not duplicated) 5. 
**Connection Management**: Confirm that parent agent's connection is properly closed before child agent starts **Expected Behavior**: - Single function response per call - Clean agent handoffs without connection leaks - Proper connection lifecycle management ## Backward Compatibility This change is **fully backward compatible**: - No changes to public APIs or method signatures - Existing single-agent flows remain unaffected - Non-live (regular async) flows continue to work as before - Only affects the internal flow control in live multi-agent scenarios Co-authored-by: Hangfei Lin COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/2588 from AlexeyChernenkoPlato:fix/double-function-response-processing-issue 3339260a4e007251137d199bdcef0ddef4487b03 PiperOrigin-RevId: 835619170 --- .../adk/flows/llm_flows/base_llm_flow.py | 35 +++++++++++++------ 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index db50e77809..824cd26be1 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -156,7 +156,7 @@ async def run_live( break logger.debug('Receive new event: %s', event) yield event - # send back the function response + # send back the function response to models if event.get_function_responses(): logger.debug( 'Sending back last function response event: %s', event @@ -164,6 +164,16 @@ async def run_live( invocation_context.live_request_queue.send_content( event.content ) + # We handle agent transfer here in `run_live` rather than + # in `_postprocess_live` to prevent duplication of function + # response processing. If agent transfer were handled in + # `_postprocess_live`, events yielded from child agent's + # `run_live` would bubble up to parent agent's `run_live`, + # causing `event.get_function_responses()` to be true in both + # child and parent, and `send_content()` to be called twice for + # the same function response. By handling agent transfer here, + # we ensure that only child agent processes its own function + # responses after the transfer. if ( event.content and event.content.parts @@ -174,7 +184,21 @@ async def run_live( await asyncio.sleep(DEFAULT_TRANSFER_AGENT_DELAY) # cancel the tasks that belongs to the closed connection. send_task.cancel() + logger.debug('Closing live connection') await llm_connection.close() + logger.debug('Live connection closed.') + # transfer to the sub agent. 
+ transfer_to_agent = event.actions.transfer_to_agent + if transfer_to_agent: + logger.debug('Transferring to agent: %s', transfer_to_agent) + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent + ) + async with Aclosing( + agent_to_run.run_live(invocation_context) + ) as agen: + async for item in agen: + yield item if ( event.content and event.content.parts @@ -638,15 +662,6 @@ async def _postprocess_live( ) yield final_event - transfer_to_agent = function_response_event.actions.transfer_to_agent - if transfer_to_agent: - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing(agent_to_run.run_live(invocation_context)) as agen: - async for item in agen: - yield item - async def _postprocess_run_processors_async( self, invocation_context: InvocationContext, llm_response: LlmResponse ) -> AsyncGenerator[Event, None]: From a1c09b724bb37513eaabaff9643eeaa68014f14d Mon Sep 17 00:00:00 2001 From: Kristen Pereira <26pkristen@gmail.com> Date: Mon, 24 Nov 2025 16:06:24 -0800 Subject: [PATCH 28/63] fix: Windows Path Handling and Normalize Cross-Platform Path Resolution in AgentLoader Merge https://github.com/google/adk-python/pull/3609 Co-authored-by: George Weale COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3609 from p-kris10:fix/windows-cmd 8cb0310bd4450097a0a7714eaac6521b6a447442 PiperOrigin-RevId: 836395714 --- src/google/adk/agents/config_agent_utils.py | 2 +- src/google/adk/cli/utils/agent_loader.py | 15 +-- tests/unittests/agents/test_agent_config.py | 93 +++++++++++++++++++ .../unittests/cli/utils/test_agent_loader.py | 41 ++++++++ 4 files changed, 143 insertions(+), 8 deletions(-) diff --git a/src/google/adk/agents/config_agent_utils.py b/src/google/adk/agents/config_agent_utils.py index 7982a9cf59..38ba2e2578 100644 --- a/src/google/adk/agents/config_agent_utils.py +++ b/src/google/adk/agents/config_agent_utils.py @@ -132,7 +132,7 @@ def resolve_agent_reference( else: return from_config( os.path.join( - referencing_agent_config_abs_path.rsplit("/", 1)[0], + os.path.dirname(referencing_agent_config_abs_path), ref_config.config_path, ) ) diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py index 9661df6fe8..d6965e5bbb 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/agent_loader.py @@ -58,7 +58,7 @@ class AgentLoader(BaseAgentLoader): """ def __init__(self, agents_dir: str): - self.agents_dir = agents_dir.rstrip("/") + self.agents_dir = str(Path(agents_dir)) self._original_sys_path = None self._agent_cache: dict[str, Union[BaseAgent, App]] = {} @@ -272,12 +272,13 @@ def _perform_load(self, agent_name: str) -> Union[BaseAgent, App]: f"No root_agent found for '{agent_name}'. 
Searched in" f" '{actual_agent_name}.agent.root_agent'," f" '{actual_agent_name}.root_agent' and" - f" '{actual_agent_name}/root_agent.yaml'.\n\nExpected directory" - f" structure:\n /\n {actual_agent_name}/\n " - " agent.py (with root_agent) OR\n root_agent.yaml\n\nThen run:" - f" adk web \n\nEnsure '{agents_dir}/{actual_agent_name}' is" - " structured correctly, an .env file can be loaded if present, and a" - f" root_agent is exposed.{hint}" + f" '{actual_agent_name}{os.sep}root_agent.yaml'.\n\nExpected directory" + f" structure:\n {os.sep}\n " + f" {actual_agent_name}{os.sep}\n agent.py (with root_agent) OR\n " + " root_agent.yaml\n\nThen run: adk web \n\nEnsure" + f" '{os.path.join(agents_dir, actual_agent_name)}' is structured" + " correctly, an .env file can be loaded if present, and a root_agent" + f" is exposed.{hint}" ) def _record_origin_metadata( diff --git a/tests/unittests/agents/test_agent_config.py b/tests/unittests/agents/test_agent_config.py index c2300f5f5d..3d8e9209f9 100644 --- a/tests/unittests/agents/test_agent_config.py +++ b/tests/unittests/agents/test_agent_config.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ntpath +import os from pathlib import Path +from textwrap import dedent from typing import Literal from typing import Type +from unittest import mock from google.adk.agents import config_agent_utils from google.adk.agents.agent_config import AgentConfig from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent_config import BaseAgentConfig +from google.adk.agents.common_configs import AgentRefConfig from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.loop_agent import LoopAgent from google.adk.agents.parallel_agent import ParallelAgent @@ -280,3 +285,91 @@ class MyCustomAgentConfig(BaseAgentConfig): config.root.model_dump() ) assert my_custom_config.other_field == "other value" + + +@pytest.mark.parametrize( + ("config_rel_path", "child_rel_path", "child_name", "instruction"), + [ + ( + Path("main.yaml"), + Path("sub_agents/child.yaml"), + "child_agent", + "I am a child agent", + ), + ( + Path("level1/level2/nested_main.yaml"), + Path("sub/nested_child.yaml"), + "nested_child", + "I am nested", + ), + ], +) +def test_resolve_agent_reference_resolves_relative_paths( + config_rel_path: Path, + child_rel_path: Path, + child_name: str, + instruction: str, + tmp_path: Path, +): + """Verify resolve_agent_reference resolves relative sub-agent paths.""" + config_file = tmp_path / config_rel_path + config_file.parent.mkdir(parents=True, exist_ok=True) + + child_config_path = config_file.parent / child_rel_path + child_config_path.parent.mkdir(parents=True, exist_ok=True) + child_config_path.write_text(dedent(f""" + agent_class: LlmAgent + name: {child_name} + model: gemini-2.0-flash + instruction: {instruction} + """).lstrip()) + + config_file.write_text(dedent(f""" + agent_class: LlmAgent + name: main_agent + model: gemini-2.0-flash + instruction: I am the main agent + sub_agents: + - config_path: {child_rel_path.as_posix()} + """).lstrip()) + + ref_config = AgentRefConfig(config_path=child_rel_path.as_posix()) + agent = config_agent_utils.resolve_agent_reference( + ref_config, str(config_file) + ) + + assert agent.name == child_name + + config_dir = os.path.dirname(str(config_file.resolve())) + assert config_dir == str(config_file.parent.resolve()) + + expected_child_path = os.path.join(config_dir, *child_rel_path.parts) + assert 
os.path.exists(expected_child_path) + + +def test_resolve_agent_reference_uses_windows_dirname(): + """Ensure Windows-style config references resolve via os.path.dirname.""" + ref_config = AgentRefConfig(config_path="sub\\child.yaml") + recorded: dict[str, str] = {} + + def fake_from_config(path: str): + recorded["path"] = path + return "sentinel" + + with ( + mock.patch.object( + config_agent_utils, + "from_config", + autospec=True, + side_effect=fake_from_config, + ), + mock.patch.object(config_agent_utils.os, "path", ntpath), + ): + referencing = r"C:\workspace\agents\main.yaml" + result = config_agent_utils.resolve_agent_reference(ref_config, referencing) + + expected_path = ntpath.join( + ntpath.dirname(referencing), ref_config.config_path + ) + assert result == "sentinel" + assert recorded["path"] == expected_path diff --git a/tests/unittests/cli/utils/test_agent_loader.py b/tests/unittests/cli/utils/test_agent_loader.py index 5c66160aed..4950fecbd3 100644 --- a/tests/unittests/cli/utils/test_agent_loader.py +++ b/tests/unittests/cli/utils/test_agent_loader.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ntpath import os from pathlib import Path +from pathlib import PureWindowsPath +import re import sys import tempfile from textwrap import dedent +from google.adk.cli.utils import agent_loader as agent_loader_module from google.adk.cli.utils.agent_loader import AgentLoader from pydantic import ValidationError import pytest @@ -280,6 +284,43 @@ def test_load_multiple_different_agents(self): assert agent2 is not agent3 assert agent1.agent_id != agent2.agent_id != agent3.agent_id + def test_error_messages_use_os_sep_consistently(self): + """Verify error messages use os.sep instead of hardcoded '/'.""" + del self + with tempfile.TemporaryDirectory() as temp_dir: + loader = AgentLoader(temp_dir) + agent_name = "missing_agent" + + expected_path = os.path.join(temp_dir, agent_name) + + with pytest.raises(ValueError) as exc_info: + loader.load_agent(agent_name) + + exc_info.match(re.escape(expected_path)) + exc_info.match(re.escape(f"{agent_name}{os.sep}root_agent.yaml")) + exc_info.match(re.escape(f"{os.sep}")) + + def test_agent_loader_with_mocked_windows_path(self, monkeypatch): + """Mock Path() to simulate Windows behavior and catch regressions. + + REGRESSION TEST: Fails with rstrip('/'), passes with str(Path()). 
+ """ + del self + windows_path = "C:\\Users\\dev\\agents\\" + + with monkeypatch.context() as m: + m.setattr( + agent_loader_module, + "Path", + lambda path_str: PureWindowsPath(path_str), + ) + loader = AgentLoader(windows_path) + + expected = str(PureWindowsPath(windows_path)) + assert loader.agents_dir == expected + assert not loader.agents_dir.endswith("\\") + assert not loader.agents_dir.endswith("/") + def test_agent_not_found_error(self): """Test that appropriate error is raised when agent is not found.""" with tempfile.TemporaryDirectory() as temp_dir: From 4eb2a11403d0625f2dbd6add7655aa884ef7d599 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Mon, 24 Nov 2025 16:17:48 -0800 Subject: [PATCH 29/63] fix: fix bug where remote a2a agent wasn't using its a2a part converter PiperOrigin-RevId: 836399603 --- src/google/adk/agents/remote_a2a_agent.py | 14 +++++++--- .../unittests/agents/test_remote_a2a_agent.py | 27 ++++++++++++++++--- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 7b6ff5cdd9..5d42730937 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -417,7 +417,9 @@ async def _handle_a2a_response( # This is the initial response for a streaming task or the complete # response for a non-streaming task, which is the full task state. # We process this to get the initial message. - event = convert_a2a_task_to_event(task, self.name, ctx) + event = convert_a2a_task_to_event( + task, self.name, ctx, self._a2a_part_converter + ) # for streaming task, we update the event with the task status. # We update the event as Thought updates. if task and task.status and task.status.state == TaskState.submitted: @@ -429,7 +431,7 @@ async def _handle_a2a_response( ): # This is a streaming task status update with a message. event = convert_a2a_message_to_event( - update.status.message, self.name, ctx + update.status.message, self.name, ctx, self._a2a_part_converter ) if event.content and update.status.state in [ TaskState.submitted, @@ -447,7 +449,9 @@ async def _handle_a2a_response( # signals: # 1. append: True for partial updates, False for full updates. # 2. last_chunk: True for full updates, False for partial updates. - event = convert_a2a_task_to_event(task, self.name, ctx) + event = convert_a2a_task_to_event( + task, self.name, ctx, self._a2a_part_converter + ) else: # This is a streaming update without a message (e.g. status change) # or a partial artifact update. We don't emit an event for these @@ -463,7 +467,9 @@ async def _handle_a2a_response( # Otherwise, it's a regular A2AMessage for non-streaming responses. 
elif isinstance(a2a_response, A2AMessage): - event = convert_a2a_message_to_event(a2a_response, self.name, ctx) + event = convert_a2a_message_to_event( + a2a_response, self.name, ctx, self._a2a_part_converter + ) event.custom_metadata = event.custom_metadata or {} if a2a_response.context_id: diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index 561a381870..fd722abf3f 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -723,6 +723,7 @@ async def test_handle_a2a_response_success_with_message(self): mock_a2a_message, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -760,6 +761,7 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check the parts are not updated as Thought assert result.content.parts[0].thought is None @@ -864,6 +866,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check the parts are updated as Thought assert result.content.parts[0].thought is True @@ -909,6 +912,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self): mock_a2a_message, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -954,6 +958,7 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message( mock_a2a_message, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1009,7 +1014,10 @@ async def test_handle_a2a_response_with_artifact_update(self): assert result == mock_event mock_convert.assert_called_once_with( - mock_a2a_task, self.agent.name, self.mock_context + mock_a2a_task, + self.agent.name, + self.mock_context, + self.agent._a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1039,6 +1047,8 @@ class TestRemoteA2aAgentMessageHandlingFromFactory: def setup_method(self): """Setup test fixtures.""" + self.mock_a2a_part_converter = Mock() + self.agent_card = create_test_agent_card() self.agent = RemoteA2aAgent( name="test_agent", @@ -1046,6 +1056,7 @@ def setup_method(self): a2a_client_factory=ClientFactory( config=ClientConfig(httpx_client=httpx.AsyncClient()), ), + a2a_part_converter=self.mock_a2a_part_converter, ) # Mock session and context @@ -1173,7 +1184,10 @@ async def test_handle_a2a_response_success_with_message(self): assert result == mock_event mock_convert.assert_called_once_with( - mock_a2a_message, self.agent.name, self.mock_context + mock_a2a_message, + self.agent.name, + self.mock_context, + self.mock_a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1211,6 +1225,7 @@ async def test_handle_a2a_response_with_task_completed_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, + self.mock_a2a_part_converter, ) # Check the parts are not updated as Thought assert result.content.parts[0].thought is None @@ -1251,6 +1266,7 @@ async def test_handle_a2a_response_with_task_submitted_and_no_update(self): mock_a2a_task, self.agent.name, self.mock_context, + self.agent._a2a_part_converter, ) # Check the 
parts are updated as Thought assert result.content.parts[0].thought is True @@ -1296,6 +1312,7 @@ async def test_handle_a2a_response_with_task_status_update_with_message(self): mock_a2a_message, self.agent.name, self.mock_context, + self.agent._a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1341,6 +1358,7 @@ async def test_handle_a2a_response_with_task_status_working_update_with_message( mock_a2a_message, self.agent.name, self.mock_context, + self.agent._a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None @@ -1396,7 +1414,10 @@ async def test_handle_a2a_response_with_artifact_update(self): assert result == mock_event mock_convert.assert_called_once_with( - mock_a2a_task, self.agent.name, self.mock_context + mock_a2a_task, + self.agent.name, + self.mock_context, + self.agent._a2a_part_converter, ) # Check that metadata was added assert result.custom_metadata is not None From 728abe4d81c0d8d407aa196ed446baff574bfcd6 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Mon, 24 Nov 2025 16:48:17 -0800 Subject: [PATCH 30/63] feat(agents): Add warning for duplicate sub-agent names Co-authored-by: Shangjie Chen PiperOrigin-RevId: 836409638 --- src/google/adk/agents/base_agent.py | 42 +++++++++ tests/unittests/agents/test_base_agent.py | 107 ++++++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a644cb8b90..e15f9af981 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -15,6 +15,7 @@ from __future__ import annotations import inspect +import logging from typing import Any from typing import AsyncGenerator from typing import Awaitable @@ -49,6 +50,8 @@ if TYPE_CHECKING: from .invocation_context import InvocationContext +logger = logging.getLogger('google_adk.' + __name__) + _SingleAgentCallback: TypeAlias = Callable[ [CallbackContext], Union[Awaitable[Optional[types.Content]], Optional[types.Content]], @@ -563,6 +566,45 @@ def validate_name(cls, value: str): ) return value + @field_validator('sub_agents', mode='after') + @classmethod + def validate_sub_agents_unique_names( + cls, value: list[BaseAgent] + ) -> list[BaseAgent]: + """Validates that all sub-agents have unique names. + + Args: + value: The list of sub-agents to validate. + + Returns: + The validated list of sub-agents. + + """ + if not value: + return value + + seen_names: set[str] = set() + duplicates: set[str] = set() + + for sub_agent in value: + name = sub_agent.name + if name in seen_names: + duplicates.add(name) + else: + seen_names.add(name) + + if duplicates: + duplicate_names_str = ', '.join( + f'`{name}`' for name in sorted(duplicates) + ) + logger.warning( + 'Found duplicate sub-agent names: %s. 
' + 'All sub-agents must have unique names.', + duplicate_names_str, + ) + + return value + def __set_parent_agent_for_sub_agents(self) -> BaseAgent: for sub_agent in self.sub_agents: if sub_agent.parent_agent is not None: diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 663179f670..259bdd51c2 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -16,6 +16,7 @@ from enum import Enum from functools import partial +import logging from typing import AsyncGenerator from typing import List from typing import Optional @@ -854,6 +855,112 @@ def test_set_parent_agent_for_sub_agent_twice( ) +def test_validate_sub_agents_unique_names_single_duplicate( + request: pytest.FixtureRequest, + caplog: pytest.LogCaptureFixture, +): + """Test that duplicate sub-agent names logs a warning.""" + duplicate_name = f'{request.function.__name__}_duplicate_agent' + sub_agent_1 = _TestingAgent(name=duplicate_name) + sub_agent_2 = _TestingAgent(name=duplicate_name) + + with caplog.at_level(logging.WARNING): + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=[sub_agent_1, sub_agent_2], + ) + assert f'Found duplicate sub-agent names: `{duplicate_name}`' in caplog.text + + +def test_validate_sub_agents_unique_names_multiple_duplicates( + request: pytest.FixtureRequest, + caplog: pytest.LogCaptureFixture, +): + """Test that multiple duplicate sub-agent names are all reported.""" + duplicate_name_1 = f'{request.function.__name__}_duplicate_1' + duplicate_name_2 = f'{request.function.__name__}_duplicate_2' + + sub_agents = [ + _TestingAgent(name=duplicate_name_1), + _TestingAgent(name=f'{request.function.__name__}_unique'), + _TestingAgent(name=duplicate_name_1), # First duplicate + _TestingAgent(name=duplicate_name_2), + _TestingAgent(name=duplicate_name_2), # Second duplicate + ] + + with caplog.at_level(logging.WARNING): + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + # Verify each duplicate name appears exactly once in the error message + assert caplog.text.count(duplicate_name_1) == 1 + assert caplog.text.count(duplicate_name_2) == 1 + # Verify both duplicate names are present + assert duplicate_name_1 in caplog.text + assert duplicate_name_2 in caplog.text + + +def test_validate_sub_agents_unique_names_triple_duplicate( + request: pytest.FixtureRequest, + caplog: pytest.LogCaptureFixture, +): + """Test that a name appearing three times is reported only once.""" + duplicate_name = f'{request.function.__name__}_triple_duplicate' + + sub_agents = [ + _TestingAgent(name=duplicate_name), + _TestingAgent(name=f'{request.function.__name__}_unique'), + _TestingAgent(name=duplicate_name), # Second occurrence + _TestingAgent(name=duplicate_name), # Third occurrence + ] + + with caplog.at_level(logging.WARNING): + _ = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + # Verify the duplicate name appears exactly once in the error message + # (not three times even though it appears three times in the list) + assert caplog.text.count(duplicate_name) == 1 + assert duplicate_name in caplog.text + + +def test_validate_sub_agents_unique_names_no_duplicates( + request: pytest.FixtureRequest, +): + """Test that unique sub-agent names pass validation.""" + sub_agents = [ + _TestingAgent(name=f'{request.function.__name__}_sub_agent_1'), + _TestingAgent(name=f'{request.function.__name__}_sub_agent_2'), + 
_TestingAgent(name=f'{request.function.__name__}_sub_agent_3'), + ] + + parent = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=sub_agents, + ) + + assert len(parent.sub_agents) == 3 + assert parent.sub_agents[0].name == f'{request.function.__name__}_sub_agent_1' + assert parent.sub_agents[1].name == f'{request.function.__name__}_sub_agent_2' + assert parent.sub_agents[2].name == f'{request.function.__name__}_sub_agent_3' + + +def test_validate_sub_agents_unique_names_empty_list( + request: pytest.FixtureRequest, +): + """Test that empty sub-agents list passes validation.""" + parent = _TestingAgent( + name=f'{request.function.__name__}_parent', + sub_agents=[], + ) + + assert len(parent.sub_agents) == 0 + + if __name__ == '__main__': pytest.main([__file__]) From 4a42d0d9d81b7aab98371427f70a7707dbfb8bc4 Mon Sep 17 00:00:00 2001 From: qieqieplus Date: Mon, 24 Nov 2025 16:51:06 -0800 Subject: [PATCH 31/63] feat: Add enum constraint to `agent_name` for `transfer_to_agent` Merge https://github.com/google/adk-python/pull/2437 Current implementation of `transfer_to_agent` doesn't enforce strict constraints on agent names, we could use JSON Schema's enum definition to implement stricter constraints. Co-authored-by: Xuan Yang COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/2437 from qieqieplus:main 052e8e73b9d61c0998573a2077f15864873d0dd7 PiperOrigin-RevId: 836410397 --- .../adk/flows/llm_flows/agent_transfer.py | 35 ++-- src/google/adk/tools/__init__.py | 5 + .../adk/tools/transfer_to_agent_tool.py | 60 +++++++ ...test_agent_transfer_system_instructions.py | 39 +++-- .../tools/test_build_function_declaration.py | 16 ++ .../tools/test_transfer_to_agent_tool.py | 164 ++++++++++++++++++ 6 files changed, 286 insertions(+), 33 deletions(-) create mode 100644 tests/unittests/tools/test_transfer_to_agent_tool.py diff --git a/src/google/adk/flows/llm_flows/agent_transfer.py b/src/google/adk/flows/llm_flows/agent_transfer.py index 037b8c6d50..ea144bf75d 100644 --- a/src/google/adk/flows/llm_flows/agent_transfer.py +++ b/src/google/adk/flows/llm_flows/agent_transfer.py @@ -24,9 +24,8 @@ from ...agents.invocation_context import InvocationContext from ...events.event import Event from ...models.llm_request import LlmRequest -from ...tools.function_tool import FunctionTool from ...tools.tool_context import ToolContext -from ...tools.transfer_to_agent_tool import transfer_to_agent +from ...tools.transfer_to_agent_tool import TransferToAgentTool from ._base_llm_processor import BaseLlmRequestProcessor if typing.TYPE_CHECKING: @@ -50,13 +49,18 @@ async def run_async( if not transfer_targets: return + transfer_to_agent_tool = TransferToAgentTool( + agent_names=[agent.name for agent in transfer_targets] + ) + llm_request.append_instructions([ _build_target_agents_instructions( - invocation_context.agent, transfer_targets + transfer_to_agent_tool.name, + invocation_context.agent, + transfer_targets, ) ]) - transfer_to_agent_tool = FunctionTool(func=transfer_to_agent) tool_context = ToolContext(invocation_context) await transfer_to_agent_tool.process_llm_request( tool_context=tool_context, llm_request=llm_request @@ -80,10 +84,13 @@ def _build_target_agents_info(target_agent: BaseAgent) -> str: def _build_target_agents_instructions( - agent: LlmAgent, target_agents: list[BaseAgent] + tool_name: str, + agent: LlmAgent, + target_agents: list[BaseAgent], ) -> str: # Build list of available agent names for the NOTE - # target_agents already includes parent agent if 
applicable, so no need to add it again + # target_agents already includes parent agent if applicable, + # so no need to add it again available_agent_names = [target_agent.name for target_agent in target_agents] # Sort for consistency @@ -101,15 +108,16 @@ def _build_target_agents_instructions( _build_target_agents_info(target_agent) for target_agent in target_agents ])} -If you are the best to answer the question according to your description, you -can answer it. +If you are the best to answer the question according to your description, +you can answer it. If another agent is better for answering the question according to its -description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the -question to that agent. When transferring, do not generate any text other than -the function call. +description, call `{tool_name}` function to transfer the question to that +agent. When transferring, do not generate any text other than the function +call. -**NOTE**: the only available agents for `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function are {formatted_agent_names}. +**NOTE**: the only available agents for `{tool_name}` function are +{formatted_agent_names}. """ if agent.parent_agent and not agent.disallow_transfer_to_parent: @@ -119,9 +127,6 @@ def _build_target_agents_instructions( return si -_TRANSFER_TO_AGENT_FUNCTION_NAME = transfer_to_agent.__name__ - - def _get_transfer_targets(agent: LlmAgent) -> list[BaseAgent]: from ...agents.llm_agent import LlmAgent diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py index 1777bd93c5..d359abb728 100644 --- a/src/google/adk/tools/__init__.py +++ b/src/google/adk/tools/__init__.py @@ -37,6 +37,7 @@ from .preload_memory_tool import preload_memory_tool as preload_memory from .tool_context import ToolContext from .transfer_to_agent_tool import transfer_to_agent + from .transfer_to_agent_tool import TransferToAgentTool from .url_context_tool import url_context from .vertex_ai_search_tool import VertexAiSearchTool @@ -75,6 +76,10 @@ 'preload_memory': ('.preload_memory_tool', 'preload_memory_tool'), 'ToolContext': ('.tool_context', 'ToolContext'), 'transfer_to_agent': ('.transfer_to_agent_tool', 'transfer_to_agent'), + 'TransferToAgentTool': ( + '.transfer_to_agent_tool', + 'TransferToAgentTool', + ), 'url_context': ('.url_context_tool', 'url_context'), 'VertexAiSearchTool': ('.vertex_ai_search_tool', 'VertexAiSearchTool'), 'MCPToolset': ('.mcp_tool.mcp_toolset', 'MCPToolset'), diff --git a/src/google/adk/tools/transfer_to_agent_tool.py b/src/google/adk/tools/transfer_to_agent_tool.py index 99ee234b30..2124e6aab9 100644 --- a/src/google/adk/tools/transfer_to_agent_tool.py +++ b/src/google/adk/tools/transfer_to_agent_tool.py @@ -14,6 +14,12 @@ from __future__ import annotations +from typing import Optional + +from google.genai import types +from typing_extensions import override + +from .function_tool import FunctionTool from .tool_context import ToolContext @@ -23,7 +29,61 @@ def transfer_to_agent(agent_name: str, tool_context: ToolContext) -> None: This tool hands off control to another agent when it's more suitable to answer the user's question according to the agent's description. + Note: + For most use cases, you should use TransferToAgentTool instead of this + function directly. TransferToAgentTool provides additional enum constraints + that prevent LLMs from hallucinating invalid agent names. + Args: agent_name: the agent name to transfer to. 
""" tool_context.actions.transfer_to_agent = agent_name + + +class TransferToAgentTool(FunctionTool): + """A specialized FunctionTool for agent transfer with enum constraints. + + This tool enhances the base transfer_to_agent function by adding JSON Schema + enum constraints to the agent_name parameter. This prevents LLMs from + hallucinating invalid agent names by restricting choices to only valid agents. + + Attributes: + agent_names: List of valid agent names that can be transferred to. + """ + + def __init__( + self, + agent_names: list[str], + ): + """Initialize the TransferToAgentTool. + + Args: + agent_names: List of valid agent names that can be transferred to. + """ + super().__init__(func=transfer_to_agent) + self._agent_names = agent_names + + @override + def _get_declaration(self) -> Optional[types.FunctionDeclaration]: + """Add enum constraint to the agent_name parameter. + + Returns: + FunctionDeclaration with enum constraint on agent_name parameter. + """ + function_decl = super()._get_declaration() + if not function_decl: + return function_decl + + # Handle parameters (types.Schema object) + if function_decl.parameters: + agent_name_schema = function_decl.parameters.properties.get('agent_name') + if agent_name_schema: + agent_name_schema.enum = self._agent_names + + # Handle parameters_json_schema (dict) + if function_decl.parameters_json_schema: + properties = function_decl.parameters_json_schema.get('properties', {}) + if 'agent_name' in properties: + properties['agent_name']['enum'] = self._agent_names + + return function_decl diff --git a/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py b/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py index be97a627a1..b180a589cb 100644 --- a/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py +++ b/tests/unittests/flows/llm_flows/test_agent_transfer_system_instructions.py @@ -126,15 +126,16 @@ async def test_agent_transfer_includes_sorted_agent_names_in_system_instructions Agent description: Peer agent -If you are the best to answer the question according to your description, you -can answer it. +If you are the best to answer the question according to your description, +you can answer it. If another agent is better for answering the question according to its -description, call `transfer_to_agent` function to transfer the -question to that agent. When transferring, do not generate any text other than -the function call. +description, call `transfer_to_agent` function to transfer the question to that +agent. When transferring, do not generate any text other than the function +call. -**NOTE**: the only available agents for `transfer_to_agent` function are `a_agent`, `m_agent`, `parent_agent`, `peer_agent`, `z_agent`. +**NOTE**: the only available agents for `transfer_to_agent` function are +`a_agent`, `m_agent`, `parent_agent`, `peer_agent`, `z_agent`. If neither you nor the other agents are best for the question, transfer to your parent agent parent_agent.""" @@ -189,15 +190,16 @@ async def test_agent_transfer_system_instructions_without_parent(): Agent description: Second sub-agent -If you are the best to answer the question according to your description, you -can answer it. +If you are the best to answer the question according to your description, +you can answer it. If another agent is better for answering the question according to its -description, call `transfer_to_agent` function to transfer the -question to that agent. 
When transferring, do not generate any text other than -the function call. +description, call `transfer_to_agent` function to transfer the question to that +agent. When transferring, do not generate any text other than the function +call. -**NOTE**: the only available agents for `transfer_to_agent` function are `agent1`, `agent2`.""" +**NOTE**: the only available agents for `transfer_to_agent` function are +`agent1`, `agent2`.""" assert expected_content in instructions @@ -248,15 +250,16 @@ async def test_agent_transfer_simplified_parent_instructions(): Agent description: Parent agent -If you are the best to answer the question according to your description, you -can answer it. +If you are the best to answer the question according to your description, +you can answer it. If another agent is better for answering the question according to its -description, call `transfer_to_agent` function to transfer the -question to that agent. When transferring, do not generate any text other than -the function call. +description, call `transfer_to_agent` function to transfer the question to that +agent. When transferring, do not generate any text other than the function +call. -**NOTE**: the only available agents for `transfer_to_agent` function are `parent_agent`, `sub_agent`. +**NOTE**: the only available agents for `transfer_to_agent` function are +`parent_agent`, `sub_agent`. If neither you nor the other agents are best for the question, transfer to your parent agent parent_agent.""" diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index f57c3d3838..8a562c7cf4 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -411,3 +411,19 @@ def transfer_to_agent(agent_name: str, tool_context: ToolContext): # Changed: Now uses Any type instead of NULL for no return annotation assert function_decl.response is not None assert function_decl.response.type is None # Any type maps to None in schema + + +def test_transfer_to_agent_tool_with_enum_constraint(): + """Test TransferToAgentTool adds enum constraint to agent_name.""" + from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool + + agent_names = ['agent_a', 'agent_b', 'agent_c'] + tool = TransferToAgentTool(agent_names=agent_names) + + function_decl = tool._get_declaration() + + assert function_decl.name == 'transfer_to_agent' + assert function_decl.parameters.type == 'OBJECT' + assert function_decl.parameters.properties['agent_name'].type == 'STRING' + assert function_decl.parameters.properties['agent_name'].enum == agent_names + assert 'tool_context' not in function_decl.parameters.properties diff --git a/tests/unittests/tools/test_transfer_to_agent_tool.py b/tests/unittests/tools/test_transfer_to_agent_tool.py new file mode 100644 index 0000000000..14b7b3abea --- /dev/null +++ b/tests/unittests/tools/test_transfer_to_agent_tool.py @@ -0,0 +1,164 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for TransferToAgentTool enum constraint functionality.""" + +from unittest.mock import patch + +from google.adk.tools.function_tool import FunctionTool +from google.adk.tools.transfer_to_agent_tool import TransferToAgentTool +from google.genai import types + + +def test_transfer_to_agent_tool_enum_constraint(): + """Test that TransferToAgentTool adds enum constraint to agent_name.""" + agent_names = ['agent_a', 'agent_b', 'agent_c'] + tool = TransferToAgentTool(agent_names=agent_names) + + decl = tool._get_declaration() + + assert decl is not None + assert decl.name == 'transfer_to_agent' + assert decl.parameters is not None + assert decl.parameters.type == types.Type.OBJECT + assert 'agent_name' in decl.parameters.properties + + agent_name_schema = decl.parameters.properties['agent_name'] + assert agent_name_schema.type == types.Type.STRING + assert agent_name_schema.enum == agent_names + + # Verify that agent_name is marked as required + assert decl.parameters.required == ['agent_name'] + + +def test_transfer_to_agent_tool_single_agent(): + """Test TransferToAgentTool with a single agent.""" + tool = TransferToAgentTool(agent_names=['single_agent']) + + decl = tool._get_declaration() + + assert decl is not None + agent_name_schema = decl.parameters.properties['agent_name'] + assert agent_name_schema.enum == ['single_agent'] + + +def test_transfer_to_agent_tool_multiple_agents(): + """Test TransferToAgentTool with multiple agents.""" + agent_names = ['agent_1', 'agent_2', 'agent_3', 'agent_4', 'agent_5'] + tool = TransferToAgentTool(agent_names=agent_names) + + decl = tool._get_declaration() + + assert decl is not None + agent_name_schema = decl.parameters.properties['agent_name'] + assert agent_name_schema.enum == agent_names + assert len(agent_name_schema.enum) == 5 + + +def test_transfer_to_agent_tool_empty_list(): + """Test TransferToAgentTool with an empty agent list.""" + tool = TransferToAgentTool(agent_names=[]) + + decl = tool._get_declaration() + + assert decl is not None + agent_name_schema = decl.parameters.properties['agent_name'] + assert agent_name_schema.enum == [] + + +def test_transfer_to_agent_tool_preserves_description(): + """Test that TransferToAgentTool preserves the original description.""" + tool = TransferToAgentTool(agent_names=['agent_a', 'agent_b']) + + decl = tool._get_declaration() + + assert decl is not None + assert decl.description is not None + assert 'Transfer the question to another agent' in decl.description + + +def test_transfer_to_agent_tool_preserves_parameter_type(): + """Test that TransferToAgentTool preserves the parameter type.""" + tool = TransferToAgentTool(agent_names=['agent_a']) + + decl = tool._get_declaration() + + assert decl is not None + agent_name_schema = decl.parameters.properties['agent_name'] + # Should still be a string type, just with enum constraint + assert agent_name_schema.type == types.Type.STRING + + +def test_transfer_to_agent_tool_no_extra_parameters(): + """Test that TransferToAgentTool doesn't add extra parameters.""" + tool = TransferToAgentTool(agent_names=['agent_a']) + + decl = tool._get_declaration() + + assert decl is not None + # Should only have agent_name parameter (tool_context is ignored) + assert len(decl.parameters.properties) == 1 + assert 'agent_name' in decl.parameters.properties + assert 'tool_context' not in decl.parameters.properties + + +def 
test_transfer_to_agent_tool_maintains_inheritance(): + """Test that TransferToAgentTool inherits from FunctionTool correctly.""" + tool = TransferToAgentTool(agent_names=['agent_a']) + + assert isinstance(tool, FunctionTool) + assert hasattr(tool, '_get_declaration') + assert hasattr(tool, 'process_llm_request') + + +def test_transfer_to_agent_tool_handles_parameters_json_schema(): + """Test that TransferToAgentTool handles parameters_json_schema format.""" + agent_names = ['agent_x', 'agent_y', 'agent_z'] + + # Create a mock FunctionDeclaration with parameters_json_schema + mock_decl = type('MockDecl', (), {})() + mock_decl.parameters = None # No Schema object + mock_decl.parameters_json_schema = { + 'type': 'object', + 'properties': { + 'agent_name': { + 'type': 'string', + 'description': 'Agent name to transfer to', + } + }, + 'required': ['agent_name'], + } + + # Temporarily patch FunctionTool._get_declaration + with patch.object( + FunctionTool, + '_get_declaration', + return_value=mock_decl, + ): + tool = TransferToAgentTool(agent_names=agent_names) + result = tool._get_declaration() + + # Verify enum was added to parameters_json_schema + assert result.parameters_json_schema is not None + assert 'agent_name' in result.parameters_json_schema['properties'] + assert ( + result.parameters_json_schema['properties']['agent_name']['enum'] + == agent_names + ) + assert ( + result.parameters_json_schema['properties']['agent_name']['type'] + == 'string' + ) + # Verify required field is preserved + assert result.parameters_json_schema['required'] == ['agent_name'] From b331d97dfb14e2cc8e3783f56d04fb14b2236fe3 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Mon, 24 Nov 2025 17:14:15 -0800 Subject: [PATCH 32/63] docs: Remove the `list_unlabeled_issues` tool from the issue triaging agent Co-authored-by: Xuan Yang PiperOrigin-RevId: 836416597 --- .../samples/adk_triaging_agent/agent.py | 33 ------------------- .../samples/adk_triaging_agent/main.py | 3 +- 2 files changed, 2 insertions(+), 34 deletions(-) diff --git a/contributing/samples/adk_triaging_agent/agent.py b/contributing/samples/adk_triaging_agent/agent.py index 9504e72dff..167eb3a616 100644 --- a/contributing/samples/adk_triaging_agent/agent.py +++ b/contributing/samples/adk_triaging_agent/agent.py @@ -78,38 +78,6 @@ APPROVAL_INSTRUCTION = "Only label them when the user approves the labeling!" -def list_unlabeled_issues(issue_count: int) -> dict[str, Any]: - """List most recent `issue_count` number of unlabeled issues in the repo. - - Args: - issue_count: number of issues to return - - Returns: - The status of this request, with a list of issues when successful. - """ - url = f"{GITHUB_BASE_URL}/search/issues" - query = f"repo:{OWNER}/{REPO} is:open is:issue no:label" - params = { - "q": query, - "sort": "created", - "order": "desc", - "per_page": issue_count, - "page": 1, - } - - try: - response = get_request(url, params) - except requests.exceptions.RequestException as e: - return error_response(f"Error: {e}") - issues = response.get("items", None) - - unlabeled_issues = [] - for issue in issues: - if not issue.get("labels", None): - unlabeled_issues.append(issue) - return {"status": "success", "issues": unlabeled_issues} - - def list_planned_untriaged_issues(issue_count: int) -> dict[str, Any]: """List planned issues without component labels (e.g., core, tools, etc.). 
@@ -276,7 +244,6 @@ def change_issue_type(issue_number: int, issue_type: str) -> dict[str, Any]: - the owner of the label if you assign the issue to an owner """, tools=[ - list_unlabeled_issues, list_planned_untriaged_issues, add_label_and_owner_to_issue, change_issue_type, diff --git a/contributing/samples/adk_triaging_agent/main.py b/contributing/samples/adk_triaging_agent/main.py index f608a696c0..f24302ac4b 100644 --- a/contributing/samples/adk_triaging_agent/main.py +++ b/contributing/samples/adk_triaging_agent/main.py @@ -144,7 +144,8 @@ async def main(): f" most recent {issue_count} planned issues that haven't been" " triaged yet (i.e., issues with 'planned' label but no component" " labels like 'core', 'tools', etc.). Then triage each of them by" - " applying appropriate component labels." + " applying appropriate component labels. If you cannot find any planned" + " issues, please don't try to triage any issues." ) response = await call_agent_async(runner, USER_ID, session.id, prompt) From c6e7d6b16a204c754ec62a5334897dbac5df2b01 Mon Sep 17 00:00:00 2001 From: Om Kute <130128563+omkute10@users.noreply.github.com> Date: Mon, 24 Nov 2025 17:29:18 -0800 Subject: [PATCH 33/63] feat(tools): Add debug logging to VertexAiSearchTool Merge https://github.com/google/adk-python/pull/3284 **Problem:** When debugging agents that utilize the `VertexAiSearchTool`, it's currently difficult to inspect the specific configuration parameters (datastore ID, engine ID, filter, max_results, etc.) being passed to the underlying Vertex AI Search API via the `LlmRequest`. This lack of visibility can hinder troubleshooting efforts related to tool configuration. **Solution:** This PR enhances the `VertexAiSearchTool` by adding a **debug-level log statement** within the `process_llm_request` method. This log precisely records the parameters being used for the Vertex AI Search configuration just before it's appended to the `LlmRequest`. This provides developers with crucial visibility into the tool's runtime behavior when debug logging is enabled, significantly improving the **debuggability** of agents using this tool. Corresponding unit tests were updated to rigorously verify this new logging output using `caplog`. Additionally, minor fixes were made to the tests to resolve Pydantic validation errors. Co-authored-by: Xuan Yang COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3284 from omkute10:feat/add-logging-vertex-search-tool 199c12bf00a57abe202401591088c0423b39b928 PiperOrigin-RevId: 836419886 --- src/google/adk/tools/vertex_ai_search_tool.py | 27 +++ .../tools/test_vertex_ai_search_tool.py | 155 ++++++++++++++++-- 2 files changed, 170 insertions(+), 12 deletions(-) diff --git a/src/google/adk/tools/vertex_ai_search_tool.py b/src/google/adk/tools/vertex_ai_search_tool.py index aff5be1552..e0c228be0e 100644 --- a/src/google/adk/tools/vertex_ai_search_tool.py +++ b/src/google/adk/tools/vertex_ai_search_tool.py @@ -14,6 +14,7 @@ from __future__ import annotations +import logging from typing import Optional from typing import TYPE_CHECKING @@ -25,6 +26,8 @@ from .base_tool import BaseTool from .tool_context import ToolContext +logger = logging.getLogger('google_adk.' 
+ __name__) + if TYPE_CHECKING: from ..models import LlmRequest @@ -102,6 +105,30 @@ async def process_llm_request( ) llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] + + # Format data_store_specs concisely for logging + if self.data_store_specs: + spec_ids = [ + spec.data_store.split('/')[-1] if spec.data_store else 'unnamed' + for spec in self.data_store_specs + ] + specs_info = ( + f'{len(self.data_store_specs)} spec(s): [{", ".join(spec_ids)}]' + ) + else: + specs_info = None + + logger.debug( + 'Adding Vertex AI Search tool config to LLM request: ' + 'datastore=%s, engine=%s, filter=%s, max_results=%s, ' + 'data_store_specs=%s', + self.data_store_id, + self.search_engine_id, + self.filter, + self.max_results, + specs_info, + ) + llm_request.config.tools.append( types.Tool( retrieval=types.Retrieval( diff --git a/tests/unittests/tools/test_vertex_ai_search_tool.py b/tests/unittests/tools/test_vertex_ai_search_tool.py index 0df19288a3..1ec1572b90 100644 --- a/tests/unittests/tools/test_vertex_ai_search_tool.py +++ b/tests/unittests/tools/test_vertex_ai_search_tool.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.sequential_agent import SequentialAgent from google.adk.models.llm_request import LlmRequest @@ -24,6 +26,10 @@ from google.genai import types import pytest +VERTEX_SEARCH_TOOL_LOGGER_NAME = ( + 'google_adk.google.adk.tools.vertex_ai_search_tool' +) + async def _create_tool_context() -> ToolContext: session_service = InMemorySessionService() @@ -121,12 +127,34 @@ def test_init_with_data_store_id(self): tool = VertexAiSearchTool(data_store_id='test_data_store') assert tool.data_store_id == 'test_data_store' assert tool.search_engine_id is None + assert tool.data_store_specs is None def test_init_with_search_engine_id(self): """Test initialization with search engine ID.""" tool = VertexAiSearchTool(search_engine_id='test_search_engine') assert tool.search_engine_id == 'test_search_engine' assert tool.data_store_id is None + assert tool.data_store_specs is None + + def test_init_with_engine_and_specs(self): + """Test initialization with search engine ID and specs.""" + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore=( + 'projects/p/locations/l/collections/c/dataStores/spec_store' + ) + ) + ] + engine_id = ( + 'projects/p/locations/l/collections/c/engines/test_search_engine' + ) + tool = VertexAiSearchTool( + search_engine_id=engine_id, + data_store_specs=specs, + ) + assert tool.search_engine_id == engine_id + assert tool.data_store_id is None + assert tool.data_store_specs == specs def test_init_with_neither_raises_error(self): """Test that initialization without either ID raises ValueError.""" @@ -146,10 +174,34 @@ def test_init_with_both_raises_error(self): data_store_id='test_data_store', search_engine_id='test_search_engine' ) + def test_init_with_specs_but_no_engine_raises_error(self): + """Test that specs without engine ID raises ValueError.""" + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore=( + 'projects/p/locations/l/collections/c/dataStores/spec_store' + ) + ) + ] + with pytest.raises( + ValueError, + match=( + 'search_engine_id must be specified if data_store_specs is' + ' specified' + ), + ): + VertexAiSearchTool( + data_store_id='test_data_store', data_store_specs=specs + ) + 
@pytest.mark.asyncio - async def test_process_llm_request_with_simple_gemini_model(self): + async def test_process_llm_request_with_simple_gemini_model(self, caplog): """Test processing LLM request with simple Gemini model name.""" - tool = VertexAiSearchTool(data_store_id='test_data_store') + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + + tool = VertexAiSearchTool( + data_store_id='test_data_store', filter='f', max_results=5 + ) tool_context = await _create_tool_context() llm_request = LlmRequest( @@ -162,17 +214,56 @@ async def test_process_llm_request_with_simple_gemini_model(self): assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 1 - assert llm_request.config.tools[0].retrieval is not None - assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert ( + retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store' + ) + assert retrieval_tool.retrieval.vertex_ai_search.engine is None + assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f' + assert retrieval_tool.retrieval.vertex_ai_search.max_results == 5 + + # Verify debug log + debug_records = [ + r + for r in caplog.records + if 'Adding Vertex AI Search tool config' in r.message + ] + assert len(debug_records) == 1 + log_message = debug_records[0].getMessage() + assert 'datastore=test_data_store' in log_message + assert 'engine=None' in log_message + assert 'filter=f' in log_message + assert 'max_results=5' in log_message + assert 'data_store_specs=None' in log_message @pytest.mark.asyncio - async def test_process_llm_request_with_path_based_gemini_model(self): + async def test_process_llm_request_with_path_based_gemini_model(self, caplog): """Test processing LLM request with path-based Gemini model name.""" - tool = VertexAiSearchTool(data_store_id='test_data_store') + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + + specs = [ + types.VertexAISearchDataStoreSpec( + dataStore=( + 'projects/p/locations/l/collections/c/dataStores/spec_store' + ) + ) + ] + engine_id = 'projects/p/locations/l/collections/c/engines/test_engine' + tool = VertexAiSearchTool( + search_engine_id=engine_id, + data_store_specs=specs, + filter='f2', + max_results=10, + ) tool_context = await _create_tool_context() llm_request = LlmRequest( - model='projects/265104255505/locations/us-central1/publishers/google/models/gemini-2.0-flash-001', + model=( + 'projects/265104255505/locations/us-central1/publishers/' + 'google/models/gemini-2.0-flash-001' + ), config=types.GenerateContentConfig(), ) @@ -182,8 +273,28 @@ async def test_process_llm_request_with_path_based_gemini_model(self): assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 1 - assert llm_request.config.tools[0].retrieval is not None - assert llm_request.config.tools[0].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[0] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert retrieval_tool.retrieval.vertex_ai_search.datastore is None + assert retrieval_tool.retrieval.vertex_ai_search.engine == engine_id + assert retrieval_tool.retrieval.vertex_ai_search.filter == 'f2' + assert retrieval_tool.retrieval.vertex_ai_search.max_results == 10 + assert 
retrieval_tool.retrieval.vertex_ai_search.data_store_specs == specs + + # Verify debug log + debug_records = [ + r + for r in caplog.records + if 'Adding Vertex AI Search tool config' in r.message + ] + assert len(debug_records) == 1 + log_message = debug_records[0].getMessage() + assert 'datastore=None' in log_message + assert f'engine={engine_id}' in log_message + assert 'filter=f2' in log_message + assert 'max_results=10' in log_message + assert 'data_store_specs=1 spec(s): [spec_store]' in log_message @pytest.mark.asyncio async def test_process_llm_request_with_gemini_1_and_other_tools_raises_error( @@ -291,9 +402,11 @@ async def test_process_llm_request_with_path_based_non_gemini_model_raises_error @pytest.mark.asyncio async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds( - self, + self, caplog ): """Test that Gemini 2.x with other tools succeeds.""" + caplog.set_level(logging.DEBUG, logger=VERTEX_SEARCH_TOOL_LOGGER_NAME) + tool = VertexAiSearchTool(data_store_id='test_data_store') tool_context = await _create_tool_context() @@ -316,5 +429,23 @@ async def test_process_llm_request_with_gemini_2_and_other_tools_succeeds( assert llm_request.config.tools is not None assert len(llm_request.config.tools) == 2 assert llm_request.config.tools[0] == existing_tool - assert llm_request.config.tools[1].retrieval is not None - assert llm_request.config.tools[1].retrieval.vertex_ai_search is not None + retrieval_tool = llm_request.config.tools[1] + assert retrieval_tool.retrieval is not None + assert retrieval_tool.retrieval.vertex_ai_search is not None + assert ( + retrieval_tool.retrieval.vertex_ai_search.datastore == 'test_data_store' + ) + + # Verify debug log + debug_records = [ + r + for r in caplog.records + if 'Adding Vertex AI Search tool config' in r.message + ] + assert len(debug_records) == 1 + log_message = debug_records[0].getMessage() + assert 'datastore=test_data_store' in log_message + assert 'engine=None' in log_message + assert 'filter=None' in log_message + assert 'max_results=None' in log_message + assert 'data_store_specs=None' in log_message From e6be5bc9c66f04f204cf1389d981fe94a714e13f Mon Sep 17 00:00:00 2001 From: Bastien Jacot-Guillarmod Date: Tue, 25 Nov 2025 04:44:28 -0800 Subject: [PATCH 34/63] fix: Add type annotations to Runner.__aenter__ PiperOrigin-RevId: 836614561 --- src/google/adk/runners.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 2bb0168928..db9828f66e 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -19,6 +19,7 @@ import logging from pathlib import Path import queue +import sys from typing import Any from typing import AsyncGenerator from typing import Callable @@ -985,16 +986,16 @@ async def run_debug( over session management, event streaming, and error handling. Args: - user_messages: Message(s) to send to the agent. Can be: - - Single string: "What is 2+2?" - - List of strings: ["Hello!", "What's my name?"] + user_messages: Message(s) to send to the agent. Can be: - Single string: + "What is 2+2?" - List of strings: ["Hello!", "What's my name?"] user_id: User identifier. Defaults to "debug_user_id". - session_id: Session identifier for conversation persistence. - Defaults to "debug_session_id". Reuse the same ID to continue a conversation. + session_id: Session identifier for conversation persistence. Defaults to + "debug_session_id". Reuse the same ID to continue a conversation. 
run_config: Optional configuration for the agent execution. - quiet: If True, suppresses console output. Defaults to False (output shown). - verbose: If True, shows detailed tool calls and responses. Defaults to False - for cleaner output showing only final agent responses. + quiet: If True, suppresses console output. Defaults to False (output + shown). + verbose: If True, shows detailed tool calls and responses. Defaults to + False for cleaner output showing only final agent responses. Returns: list[Event]: All events from all messages. @@ -1011,7 +1012,8 @@ async def run_debug( >>> await runner.run_debug(["Hello!", "What's my name?"]) Continue a debug session: - >>> await runner.run_debug("What did we discuss?") # Continues default session + >>> await runner.run_debug("What did we discuss?") # Continues default + session Separate debug sessions: >>> await runner.run_debug("Hi", user_id="alice", session_id="debug1") @@ -1353,7 +1355,12 @@ async def close(self): logger.info('Runner closed.') - async def __aenter__(self): + if sys.version_info < (3, 11): + Self = 'Runner' # pylint: disable=invalid-name + else: + from typing import Self # pylint: disable=g-import-not-at-top + + async def __aenter__(self) -> Self: """Async context manager entry.""" return self From d29261a3dc9c5a603feef27ea657c4a03bb8a089 Mon Sep 17 00:00:00 2001 From: Virtuoso633 Date: Tue, 25 Nov 2025 09:46:18 -0800 Subject: [PATCH 35/63] feat(models): Enable multi-provider support for Claude and LiteLLM Merges: https://github.com/google/adk-python/pull/2810 Co-authored-by: Xuan Yang PiperOrigin-RevId: 836706608 --- src/google/adk/models/__init__.py | 20 +++++++ src/google/adk/models/lite_llm.py | 14 ++++- src/google/adk/models/registry.py | 24 +++++++- .../unittests/agents/test_llm_agent_fields.py | 47 +++++++++++++++ tests/unittests/models/test_models.py | 58 ++++++++++++++++++- 5 files changed, 156 insertions(+), 7 deletions(-) diff --git a/src/google/adk/models/__init__.py b/src/google/adk/models/__init__.py index 9f3c2a2c48..1be0cc698e 100644 --- a/src/google/adk/models/__init__.py +++ b/src/google/adk/models/__init__.py @@ -33,3 +33,23 @@ LLMRegistry.register(Gemini) LLMRegistry.register(Gemma) LLMRegistry.register(ApigeeLlm) + +# Optionally register Claude if anthropic package is installed +try: + from .anthropic_llm import Claude + + LLMRegistry.register(Claude) + __all__.append('Claude') +except Exception: + # Claude support requires: pip install google-adk[extensions] + pass + +# Optionally register LiteLlm if litellm package is installed +try: + from .lite_llm import LiteLlm + + LLMRegistry.register(LiteLlm) + __all__.append('LiteLlm') +except Exception: + # LiteLLM support requires: pip install google-adk[extensions] + pass diff --git a/src/google/adk/models/lite_llm.py b/src/google/adk/models/lite_llm.py index 9e3698b190..162db05945 100644 --- a/src/google/adk/models/lite_llm.py +++ b/src/google/adk/models/lite_llm.py @@ -1388,11 +1388,19 @@ async def generate_content_async( def supported_models(cls) -> list[str]: """Provides the list of supported models. - LiteLlm supports all models supported by litellm. We do not keep track of - these models here. So we return an empty list. + This registers common provider prefixes. LiteLlm can handle many more, + but these patterns activate the integration for the most common use cases. + See https://docs.litellm.ai/docs/providers for a full list. Returns: A list of supported models. 
""" - return [] + return [ + # For OpenAI models (e.g., "openai/gpt-4o") + r"openai/.*", + # For Groq models via Groq API (e.g., "groq/llama3-70b-8192") + r"groq/.*", + # For Anthropic models (e.g., "anthropic/claude-3-opus-20240229") + r"anthropic/.*", + ] diff --git a/src/google/adk/models/registry.py b/src/google/adk/models/registry.py index 22e24d4c18..852996ff40 100644 --- a/src/google/adk/models/registry.py +++ b/src/google/adk/models/registry.py @@ -99,4 +99,26 @@ def resolve(model: str) -> type[BaseLlm]: if re.compile(regex).fullmatch(model): return llm_class - raise ValueError(f'Model {model} not found.') + # Provide helpful error messages for known patterns + error_msg = f'Model {model} not found.' + + # Check if it matches known patterns that require optional dependencies + if re.match(r'^claude-', model): + error_msg += ( + '\n\nClaude models require the anthropic package.' + '\nInstall it with: pip install google-adk[extensions]' + '\nOr: pip install anthropic>=0.43.0' + ) + elif '/' in model: + # Any model with provider/model format likely needs LiteLLM + error_msg += ( + '\n\nProvider-style models (e.g., "provider/model-name") require' + ' the litellm package.' + '\nInstall it with: pip install google-adk[extensions]' + '\nOr: pip install litellm>=1.75.5' + '\n\nSupported providers include: openai, groq, anthropic, and 100+' + ' others.' + '\nSee https://docs.litellm.ai/docs/providers for a full list.' + ) + + raise ValueError(error_msg) diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index c57254dbc8..577923f7bf 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -22,6 +22,9 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.models.anthropic_llm import Claude +from google.adk.models.google_llm import Gemini +from google.adk.models.lite_llm import LiteLlm from google.adk.models.llm_request import LlmRequest from google.adk.models.registry import LLMRegistry from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -411,3 +414,47 @@ async def test_handle_vais_only(self): assert len(tools) == 1 assert tools[0].name == 'vertex_ai_search' assert tools[0].__class__.__name__ == 'VertexAiSearchTool' + + +# Tests for multi-provider model support via string model names +@pytest.mark.parametrize( + 'model_name', + [ + 'gemini-1.5-flash', + 'gemini-2.0-flash-exp', + ], +) +def test_agent_with_gemini_string_model(model_name): + """Test that Agent accepts Gemini model strings and resolves to Gemini.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, Gemini) + assert agent.canonical_model.model == model_name + + +@pytest.mark.parametrize( + 'model_name', + [ + 'claude-3-5-sonnet-v2@20241022', + 'claude-sonnet-4@20250514', + ], +) +def test_agent_with_claude_string_model(model_name): + """Test that Agent accepts Claude model strings and resolves to Claude.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, Claude) + assert agent.canonical_model.model == model_name + + +@pytest.mark.parametrize( + 'model_name', + [ + 'openai/gpt-4o', + 'groq/llama3-70b-8192', + 'anthropic/claude-3-opus-20240229', + ], +) +def test_agent_with_litellm_string_model(model_name): + """Test that Agent accepts 
LiteLLM provider strings.""" + agent = LlmAgent(name='test_agent', model=model_name) + assert isinstance(agent.canonical_model, LiteLlm) + assert agent.canonical_model.model == model_name diff --git a/tests/unittests/models/test_models.py b/tests/unittests/models/test_models.py index 70246c7bc1..8575064baa 100644 --- a/tests/unittests/models/test_models.py +++ b/tests/unittests/models/test_models.py @@ -15,7 +15,7 @@ from google.adk import models from google.adk.models.anthropic_llm import Claude from google.adk.models.google_llm import Gemini -from google.adk.models.registry import LLMRegistry +from google.adk.models.lite_llm import LiteLlm import pytest @@ -34,6 +34,7 @@ ], ) def test_match_gemini_family(model_name): + """Test that Gemini models are resolved correctly.""" assert models.LLMRegistry.resolve(model_name) is Gemini @@ -51,12 +52,63 @@ def test_match_gemini_family(model_name): ], ) def test_match_claude_family(model_name): - LLMRegistry.register(Claude) - + """Test that Claude models are resolved correctly.""" assert models.LLMRegistry.resolve(model_name) is Claude +@pytest.mark.parametrize( + 'model_name', + [ + 'openai/gpt-4o', + 'openai/gpt-4o-mini', + 'groq/llama3-70b-8192', + 'groq/mixtral-8x7b-32768', + 'anthropic/claude-3-opus-20240229', + 'anthropic/claude-3-5-sonnet-20241022', + ], +) +def test_match_litellm_family(model_name): + """Test that LiteLLM models are resolved correctly.""" + assert models.LLMRegistry.resolve(model_name) is LiteLlm + + def test_non_exist_model(): with pytest.raises(ValueError) as e_info: models.LLMRegistry.resolve('non-exist-model') assert 'Model non-exist-model not found.' in str(e_info.value) + + +def test_helpful_error_for_claude_without_extensions(): + """Test that missing Claude models show helpful install instructions. + + Note: This test may pass even when anthropic IS installed, because it + only checks the error message format when a model is not found. + """ + # Use a non-existent Claude model variant to trigger error + with pytest.raises(ValueError) as e_info: + models.LLMRegistry.resolve('claude-nonexistent-model-xyz') + + error_msg = str(e_info.value) + # The error should mention anthropic package and installation instructions + # These checks work whether or not anthropic is actually installed + assert 'Model claude-nonexistent-model-xyz not found' in error_msg + assert 'anthropic package' in error_msg + assert 'pip install' in error_msg + + +def test_helpful_error_for_litellm_without_extensions(): + """Test that missing LiteLLM models show helpful install instructions. + + Note: This test may pass even when litellm IS installed, because it + only checks the error message format when a model is not found. 
+ """ + # Use a non-existent provider to trigger error + with pytest.raises(ValueError) as e_info: + models.LLMRegistry.resolve('unknown-provider/gpt-4o') + + error_msg = str(e_info.value) + # The error should mention litellm package for provider-style models + assert 'Model unknown-provider/gpt-4o not found' in error_msg + assert 'litellm package' in error_msg + assert 'pip install' in error_msg + assert 'Provider-style models' in error_msg From 5cad8a7f58b36ca8ae0e5db2d0a8fb8718d330fd Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Tue, 25 Nov 2025 10:04:23 -0800 Subject: [PATCH 36/63] fix: Throw warning when using transparent session resumption in ADK Live for Gemini API key transparent session resumption is only supported in Vertex AI APIs Co-authored-by: Hangfei Lin PiperOrigin-RevId: 836715170 --- src/google/adk/models/gemini_llm_connection.py | 2 +- src/google/adk/models/google_llm.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 0b72c79f83..15e6ed9599 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -244,7 +244,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: ] yield LlmResponse(content=types.Content(role='model', parts=parts)) if message.session_resumption_update: - logger.info('Received session resumption message: %s', message) + logger.debug('Received session resumption message: %s', message) yield ( LlmResponse( live_session_resumption_update=message.session_resumption_update diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 1bdd311104..90c2fece76 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -325,6 +325,23 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: types.Part.from_text(text=llm_request.config.system_instruction) ], ) + if ( + llm_request.live_connect_config.session_resumption + and llm_request.live_connect_config.session_resumption.transparent + ): + logger.debug( + 'session resumption config: %s', + llm_request.live_connect_config.session_resumption, + ) + logger.debug( + 'self._api_backend: %s', + self._api_backend, + ) + if self._api_backend == GoogleLLMVariant.GEMINI_API: + raise ValueError( + 'Transparent session resumption is only supported for Vertex AI' + ' backend. Please use Vertex AI backend.' + ) llm_request.live_connect_config.tools = llm_request.config.tools logger.info('Connecting to live for model: %s', llm_request.model) logger.debug('Connecting to live with llm_request:%s', llm_request) From 5453b5bfdedc91d9d668c9eac39e3bb009a7bbbf Mon Sep 17 00:00:00 2001 From: happyryan Date: Tue, 25 Nov 2025 10:29:38 -0800 Subject: [PATCH 37/63] fix: Allow image parts in user messages for Anthropic Claude Previously, image parts were always filtered out when converting content to Anthropic message parameters. This change updates the logic to only filter out image parts and log a warning when the content role is not "user". 
This enables sending image data as part of user prompts to Claude models Merges: https://github.com/google/adk-python/pull/3286 Co-authored-by: George Weale PiperOrigin-RevId: 836725196 --- src/google/adk/models/anthropic_llm.py | 8 +- tests/unittests/models/test_anthropic_llm.py | 79 ++++++++++++++++++++ 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index 6f343367a3..f965a9906d 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -155,9 +155,11 @@ def content_to_message_param( ) -> anthropic_types.MessageParam: message_block = [] for part in content.parts or []: - # Image data is not supported in Claude for model turns. - if _is_image_part(part): - logger.warning("Image data is not supported in Claude for model turns.") + # Image data is not supported in Claude for assistant turns. + if content.role != "user" and _is_image_part(part): + logger.warning( + "Image data is not supported in Claude for assistant turns." + ) continue message_block.append(part_to_message_block(part)) diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index e5ac8cc051..13d615bc32 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -20,6 +20,7 @@ from google.adk import version as adk_version from google.adk.models import anthropic_llm from google.adk.models.anthropic_llm import Claude +from google.adk.models.anthropic_llm import content_to_message_param from google.adk.models.anthropic_llm import function_declaration_to_tool_param from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse @@ -462,3 +463,81 @@ def test_part_to_message_block_with_multiple_content_items(): assert isinstance(result, dict) # Multiple text items should be joined with newlines assert result["content"] == "First part\nSecond part" + + +content_to_message_param_test_cases = [ + ( + "user_role_with_text_and_image", + Content( + role="user", + parts=[ + Part.from_text(text="What's in this image?"), + Part( + inline_data=types.Blob( + mime_type="image/jpeg", data=b"fake_image_data" + ) + ), + ], + ), + "user", + 2, # Expected content length + False, # Should not log warning + ), + ( + "model_role_with_text_and_image", + Content( + role="model", + parts=[ + Part.from_text(text="I see a cat."), + Part( + inline_data=types.Blob( + mime_type="image/png", data=b"fake_image_data" + ) + ), + ], + ), + "assistant", + 1, # Image filtered out, only text remains + True, # Should log warning + ), + ( + "assistant_role_with_text_and_image", + Content( + role="assistant", + parts=[ + Part.from_text(text="Here's what I found."), + Part( + inline_data=types.Blob( + mime_type="image/webp", data=b"fake_image_data" + ) + ), + ], + ), + "assistant", + 1, # Image filtered out, only text remains + True, # Should log warning + ), +] + + +@pytest.mark.parametrize( + "_, content, expected_role, expected_content_length, should_log_warning", + content_to_message_param_test_cases, + ids=[case[0] for case in content_to_message_param_test_cases], +) +def test_content_to_message_param_with_images( + _, content, expected_role, expected_content_length, should_log_warning +): + """Test content_to_message_param handles images correctly based on role.""" + with mock.patch("google.adk.models.anthropic_llm.logger") as mock_logger: + result = content_to_message_param(content) + + 
assert result["role"] == expected_role + assert len(result["content"]) == expected_content_length + + if should_log_warning: + mock_logger.warning.assert_called_once_with( + "Image data is not supported in Claude for assistant turns." + ) + else: + mock_logger.warning.assert_not_called() From 06e6fc91327a8bcea1bdc72f8eee94ee05cbbb91 Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 25 Nov 2025 10:47:17 -0800 Subject: [PATCH 38/63] feat: wire runtime entrypoints to service factory defaults This change routes adk run and the FastAPI server through the new session/artifact service factory, keeps the default experience backed by per-agent .adk storage Co-authored-by: George Weale PiperOrigin-RevId: 836733234 --- src/google/adk/cli/cli.py | 79 ++++++--- src/google/adk/cli/fast_api.py | 63 +++---- src/google/adk/cli/utils/service_factory.py | 138 +++++++++++++++ tests/unittests/cli/test_fast_api.py | 12 +- tests/unittests/cli/utils/test_cli.py | 131 ++++++++++++-- .../cli/utils/test_service_factory.py | 162 ++++++++++++++++++ 6 files changed, 508 insertions(+), 77 deletions(-) create mode 100644 src/google/adk/cli/utils/service_factory.py create mode 100644 tests/unittests/cli/utils/test_service_factory.py diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 5ae18aac0a..af57a687fb 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -15,6 +15,7 @@ from __future__ import annotations from datetime import datetime +from pathlib import Path from typing import Optional from typing import Union @@ -22,7 +23,6 @@ from google.genai import types from pydantic import BaseModel -from ..agents.base_agent import BaseAgent from ..agents.llm_agent import LlmAgent from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService @@ -35,8 +35,11 @@ from ..sessions.session import Session from ..utils.context_utils import Aclosing from ..utils.env_utils import is_env_enabled +from .service_registry import load_services_module from .utils import envs from .utils.agent_loader import AgentLoader +from .utils.service_factory import create_artifact_service_from_options +from .utils.service_factory import create_session_service_from_options class InputFile(BaseModel): @@ -66,7 +69,7 @@ async def run_input_file( ) with open(input_path, 'r', encoding='utf-8') as f: input_file = InputFile.model_validate_json(f.read()) - input_file.state['_time'] = datetime.now() + input_file.state['_time'] = datetime.now().isoformat() session = await session_service.create_session( app_name=app_name, user_id=user_id, state=input_file.state @@ -134,6 +137,8 @@ async def run_cli( saved_session_file: Optional[str] = None, save_session: bool, session_id: Optional[str] = None, + session_service_uri: Optional[str] = None, + artifact_service_uri: Optional[str] = None, ) -> None: """Runs an interactive CLI for a certain agent. @@ -148,24 +153,47 @@ async def run_cli( contains a previously saved session, exclusive with input_file. save_session: bool, whether to save the session on exit. session_id: Optional[str], the session ID to save the session to on exit. + session_service_uri: Optional[str], custom session service URI. + artifact_service_uri: Optional[str], custom artifact service URI. 
""" + agent_parent_path = Path(agent_parent_dir).resolve() + agent_root = agent_parent_path / agent_folder_name + load_services_module(str(agent_root)) + user_id = 'test_user' - artifact_service = InMemoryArtifactService() - session_service = InMemorySessionService() - credential_service = InMemoryCredentialService() + # Create session and artifact services using factory functions + session_service = create_session_service_from_options( + base_dir=agent_root, + session_service_uri=session_service_uri, + ) - user_id = 'test_user' - agent_or_app = AgentLoader(agents_dir=agent_parent_dir).load_agent( + artifact_service = create_artifact_service_from_options( + base_dir=agent_root, + artifact_service_uri=artifact_service_uri, + ) + + credential_service = InMemoryCredentialService() + agents_dir = str(agent_parent_path) + agent_or_app = AgentLoader(agents_dir=agents_dir).load_agent( agent_folder_name ) session_app_name = ( agent_or_app.name if isinstance(agent_or_app, App) else agent_folder_name ) - session = await session_service.create_session( - app_name=session_app_name, user_id=user_id - ) if not is_env_enabled('ADK_DISABLE_LOAD_DOTENV'): - envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) + envs.load_dotenv_for_agent(agent_folder_name, agents_dir) + + # Helper function for printing events + def _print_event(event) -> None: + content = event.content + if not content or not content.parts: + return + text_parts = [part.text for part in content.parts if part.text] + if not text_parts: + return + author = event.author or 'system' + click.echo(f'[{author}]: {"".join(text_parts)}') + if input_file: session = await run_input_file( app_name=session_app_name, @@ -177,16 +205,22 @@ async def run_cli( input_path=input_file, ) elif saved_session_file: + # Load the saved session from file with open(saved_session_file, 'r', encoding='utf-8') as f: loaded_session = Session.model_validate_json(f.read()) + # Create a new session in the service, copying state from the file + session = await session_service.create_session( + app_name=session_app_name, + user_id=user_id, + state=loaded_session.state if loaded_session else None, + ) + + # Append events from the file to the new session and display them if loaded_session: for event in loaded_session.events: await session_service.append_event(session, event) - content = event.content - if not content or not content.parts or not content.parts[0].text: - continue - click.echo(f'[{event.author}]: {content.parts[0].text}') + _print_event(event) await run_interactively( agent_or_app, @@ -196,6 +230,9 @@ async def run_cli( credential_service, ) else: + session = await session_service.create_session( + app_name=session_app_name, user_id=user_id + ) click.echo(f'Running agent {agent_or_app.name}, type exit to exit.') await run_interactively( agent_or_app, @@ -207,9 +244,7 @@ async def run_cli( if save_session: session_id = session_id or input('Session ID to save: ') - session_path = ( - f'{agent_parent_dir}/{agent_folder_name}/{session_id}.session.json' - ) + session_path = agent_root / f'{session_id}.session.json' # Fetch the session again to get all the details. 
session = await session_service.get_session( @@ -217,9 +252,9 @@ async def run_cli( user_id=session.user_id, session_id=session.id, ) - with open(session_path, 'w', encoding='utf-8') as f: - f.write( - session.model_dump_json(indent=2, exclude_none=True, by_alias=True) - ) + session_path.write_text( + session.model_dump_json(indent=2, exclude_none=True, by_alias=True), + encoding='utf-8', + ) print('Session saved to', session_path) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index eec6bb646b..86c7ca55c6 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -34,20 +34,19 @@ from starlette.types import Lifespan from watchdog.observers import Observer -from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager -from ..memory.in_memory_memory_service import InMemoryMemoryService from ..runners import Runner -from ..sessions.in_memory_session_service import InMemorySessionService from .adk_web_server import AdkWebServer -from .service_registry import get_service_registry from .service_registry import load_services_module from .utils import envs from .utils import evals from .utils.agent_change_handler import AgentChangeEventHandler from .utils.agent_loader import AgentLoader +from .utils.service_factory import create_artifact_service_from_options +from .utils.service_factory import create_memory_service_from_options +from .utils.service_factory import create_session_service_from_options logger = logging.getLogger("google_adk." + __name__) @@ -74,6 +73,8 @@ def get_fast_api_app( logo_text: Optional[str] = None, logo_image_url: Optional[str] = None, ) -> FastAPI: + # Convert to absolute path for consistency + agents_dir = str(Path(agents_dir).resolve()) # Set up eval managers. if eval_storage_uri: @@ -91,48 +92,32 @@ def get_fast_api_app( # Load services.py from agents_dir for custom service registration. load_services_module(agents_dir) - service_registry = get_service_registry() - # Build the Memory service - if memory_service_uri: - memory_service = service_registry.create_memory_service( - memory_service_uri, agents_dir=agents_dir + try: + memory_service = create_memory_service_from_options( + base_dir=agents_dir, + memory_service_uri=memory_service_uri, ) - if not memory_service: - raise click.ClickException( - "Unsupported memory service URI: %s" % memory_service_uri - ) - else: - memory_service = InMemoryMemoryService() + except ValueError as exc: + raise click.ClickException(str(exc)) from exc # Build the Session service - if session_service_uri: - session_kwargs = session_db_kwargs or {} - session_service = service_registry.create_session_service( - session_service_uri, agents_dir=agents_dir, **session_kwargs - ) - if not session_service: - # Fallback to DatabaseSessionService if the service registry doesn't - # support the session service URI scheme. 
- from ..sessions.database_session_service import DatabaseSessionService - - session_service = DatabaseSessionService( - db_url=session_service_uri, **session_kwargs - ) - else: - session_service = InMemorySessionService() + session_service = create_session_service_from_options( + base_dir=agents_dir, + session_service_uri=session_service_uri, + session_db_kwargs=session_db_kwargs, + per_agent=True, # Multi-agent mode + ) # Build the Artifact service - if artifact_service_uri: - artifact_service = service_registry.create_artifact_service( - artifact_service_uri, agents_dir=agents_dir + try: + artifact_service = create_artifact_service_from_options( + base_dir=agents_dir, + artifact_service_uri=artifact_service_uri, + per_agent=True, # Multi-agent mode ) - if not artifact_service: - raise click.ClickException( - "Unsupported artifact service URI: %s" % artifact_service_uri - ) - else: - artifact_service = InMemoryArtifactService() + except ValueError as exc: + raise click.ClickException(str(exc)) from exc # Build the Credential service credential_service = InMemoryCredentialService() diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py new file mode 100644 index 0000000000..50064f4b8f --- /dev/null +++ b/src/google/adk/cli/utils/service_factory.py @@ -0,0 +1,138 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any +from typing import Optional + +from ...artifacts.base_artifact_service import BaseArtifactService +from ...memory.base_memory_service import BaseMemoryService +from ...sessions.base_session_service import BaseSessionService +from ..service_registry import get_service_registry +from .local_storage import create_local_artifact_service + +logger = logging.getLogger("google_adk." + __name__) + + +def create_session_service_from_options( + *, + base_dir: Path | str, + session_service_uri: Optional[str] = None, + session_db_kwargs: Optional[dict[str, Any]] = None, + per_agent: bool = False, +) -> BaseSessionService: + """Creates a session service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + kwargs: dict[str, Any] = { + "agents_dir": str(base_path), + "per_agent": per_agent, + } + if session_db_kwargs: + kwargs.update(session_db_kwargs) + + if session_service_uri: + if per_agent: + logger.warning( + "per_agent is not supported with remote session service URIs," + " ignoring" + ) + logger.info("Using session service URI: %s", session_service_uri) + service = registry.create_session_service(session_service_uri, **kwargs) + if service is not None: + return service + + # Fallback to DatabaseSessionService if the registry doesn't support the + # session service URI scheme. This keeps support for SQLAlchemy-compatible + # databases like AlloyDB or Cloud Spanner without explicit registration. 
+ from ...sessions.database_session_service import DatabaseSessionService + + fallback_kwargs = dict(kwargs) + fallback_kwargs.pop("agents_dir", None) + fallback_kwargs.pop("per_agent", None) + logger.info( + "Falling back to DatabaseSessionService for URI: %s", + session_service_uri, + ) + return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs) + + logger.info("Using in-memory session service") + from ...sessions.in_memory_session_service import InMemorySessionService + + return InMemorySessionService() + + +def create_memory_service_from_options( + *, + base_dir: Path | str, + memory_service_uri: Optional[str] = None, +) -> BaseMemoryService: + """Creates a memory service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + if memory_service_uri: + logger.info("Using memory service URI: %s", memory_service_uri) + service = registry.create_memory_service( + memory_service_uri, + agents_dir=str(base_path), + ) + if service is None: + raise ValueError(f"Unsupported memory service URI: {memory_service_uri}") + return service + + logger.info("Using in-memory memory service") + from ...memory.in_memory_memory_service import InMemoryMemoryService + + return InMemoryMemoryService() + + +def create_artifact_service_from_options( + *, + base_dir: Path | str, + artifact_service_uri: Optional[str] = None, + per_agent: bool = False, +) -> BaseArtifactService: + """Creates an artifact service based on CLI/web options.""" + base_path = Path(base_dir) + registry = get_service_registry() + + if artifact_service_uri: + if per_agent: + logger.warning( + "per_agent is not supported with remote artifact service URIs," + " ignoring" + ) + logger.info("Using artifact service URI: %s", artifact_service_uri) + service = registry.create_artifact_service( + artifact_service_uri, + agents_dir=str(base_path), + per_agent=per_agent, + ) + if service is None: + logger.warning( + "Unsupported artifact service URI: %s, falling back to in-memory", + artifact_service_uri, + ) + from ...artifacts.in_memory_artifact_service import InMemoryArtifactService + + return InMemoryArtifactService() + return service + + if per_agent: + logger.info("Using shared file artifact service rooted at %s", base_dir) + return create_local_artifact_service(base_dir=base_path, per_agent=per_agent) diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index d50bfcd8e5..a8b1ef2f2f 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -327,15 +327,15 @@ def test_app( with ( patch("signal.signal", return_value=None), patch( - "google.adk.cli.fast_api.InMemorySessionService", + "google.adk.cli.fast_api.create_session_service_from_options", return_value=mock_session_service, ), patch( - "google.adk.cli.fast_api.InMemoryArtifactService", + "google.adk.cli.fast_api.create_artifact_service_from_options", return_value=mock_artifact_service, ), patch( - "google.adk.cli.fast_api.InMemoryMemoryService", + "google.adk.cli.fast_api.create_memory_service_from_options", return_value=mock_memory_service, ), patch( @@ -472,15 +472,15 @@ def test_app_with_a2a( with ( patch("signal.signal", return_value=None), patch( - "google.adk.cli.fast_api.InMemorySessionService", + "google.adk.cli.fast_api.create_session_service_from_options", return_value=mock_session_service, ), patch( - "google.adk.cli.fast_api.InMemoryArtifactService", + "google.adk.cli.fast_api.create_artifact_service_from_options", 
return_value=mock_artifact_service, ), patch( - "google.adk.cli.fast_api.InMemoryMemoryService", + "google.adk.cli.fast_api.create_memory_service_from_options", return_value=mock_memory_service, ), patch( diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 0de59598b3..33ddbf495c 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -28,7 +28,12 @@ import click from google.adk.agents.base_agent import BaseAgent from google.adk.apps.app import App +from google.adk.artifacts.file_artifact_service import FileArtifactService +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService import google.adk.cli.cli as cli +from google.adk.cli.utils.service_factory import create_artifact_service_from_options +from google.adk.sessions.in_memory_session_service import InMemorySessionService import pytest @@ -151,9 +156,9 @@ def _echo(msg: str) -> None: input_path = tmp_path / "input.json" input_path.write_text(json.dumps(input_json)) - artifact_service = cli.InMemoryArtifactService() - session_service = cli.InMemorySessionService() - credential_service = cli.InMemoryCredentialService() + artifact_service = InMemoryArtifactService() + session_service = InMemorySessionService() + credential_service = InMemoryCredentialService() dummy_root = BaseAgent(name="root") session = await cli.run_input_file( @@ -189,6 +194,34 @@ async def test_run_cli_with_input_file(fake_agent, tmp_path: Path) -> None: ) +@pytest.mark.asyncio +async def test_run_cli_loads_services_module( + fake_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + """run_cli should load custom services from the agents directory.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": ["ping"]} + input_path = tmp_path / "input.json" + input_path.write_text(json.dumps(input_json)) + + loaded_dirs: list[str] = [] + monkeypatch.setattr( + cli, "load_services_module", lambda path: loaded_dirs.append(path) + ) + + agent_root = parent_dir / folder_name + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + ) + + assert loaded_dirs == [str(agent_root.resolve())] + + @pytest.mark.asyncio async def test_run_cli_app_uses_app_name_for_sessions( fake_app_agent, tmp_path: Path, monkeypatch: pytest.MonkeyPatch @@ -197,15 +230,20 @@ async def test_run_cli_app_uses_app_name_for_sessions( parent_dir, folder_name, app_name = fake_app_agent created_app_names: List[str] = [] - original_session_cls = cli.InMemorySessionService - - class _SpySessionService(original_session_cls): + class _SpySessionService(InMemorySessionService): async def create_session(self, *, app_name: str, **kwargs: Any) -> Any: created_app_names.append(app_name) return await super().create_session(app_name=app_name, **kwargs) - monkeypatch.setattr(cli, "InMemorySessionService", _SpySessionService) + spy_session_service = _SpySessionService() + + def _session_factory(**_: Any) -> InMemorySessionService: + return spy_session_service + + monkeypatch.setattr( + cli, "create_session_service_from_options", _session_factory + ) input_json = {"state": {}, "queries": ["ping"]} input_path = tmp_path / "input_app.json" @@ -253,16 +291,89 @@ async def test_run_cli_save_session( assert "id" in data and "events" in data +def 
test_create_artifact_service_defaults_to_file(tmp_path: Path) -> None: + """Service factory should default to FileArtifactService when URI is unset.""" + service = create_artifact_service_from_options(base_dir=tmp_path) + assert isinstance(service, FileArtifactService) + expected_root = Path(tmp_path) / ".adk" / "artifacts" + assert service.root_dir == expected_root + assert expected_root.exists() + + +def test_create_artifact_service_per_agent_uses_shared_root( + tmp_path: Path, +) -> None: + """Multi-agent mode should still use a single file artifact service.""" + service = create_artifact_service_from_options( + base_dir=tmp_path, per_agent=True + ) + assert isinstance(service, FileArtifactService) + expected_root = Path(tmp_path) / ".adk" / "artifacts" + assert service.root_dir == expected_root + assert expected_root.exists() + + +def test_create_artifact_service_respects_memory_uri(tmp_path: Path) -> None: + """Service factory should honor memory:// URIs.""" + service = create_artifact_service_from_options( + base_dir=tmp_path, artifact_service_uri="memory://" + ) + assert isinstance(service, InMemoryArtifactService) + + +def test_create_artifact_service_accepts_file_uri(tmp_path: Path) -> None: + """Service factory should allow custom local roots via file:// URIs.""" + custom_root = tmp_path / "custom_artifacts" + service = create_artifact_service_from_options( + base_dir=tmp_path, artifact_service_uri=custom_root.as_uri() + ) + assert isinstance(service, FileArtifactService) + assert service.root_dir == custom_root + assert custom_root.exists() + + +def test_create_artifact_service_file_uri_rejects_per_agent(tmp_path: Path): + """file:// URIs are incompatible with per-agent mode.""" + custom_root = tmp_path / "custom" + with pytest.raises(ValueError, match="multi-agent"): + create_artifact_service_from_options( + base_dir=tmp_path, + artifact_service_uri=custom_root.as_uri(), + per_agent=True, + ) + + +@pytest.mark.asyncio +async def test_run_cli_accepts_memory_scheme( + fake_agent, tmp_path: Path +) -> None: + """run_cli should allow configuring in-memory services via memory:// URIs.""" + parent_dir, folder_name = fake_agent + input_json = {"state": {}, "queries": []} + input_path = tmp_path / "noop.json" + input_path.write_text(json.dumps(input_json)) + + await cli.run_cli( + agent_parent_dir=str(parent_dir), + agent_folder_name=folder_name, + input_file=str(input_path), + saved_session_file=None, + save_session=False, + session_service_uri="memory://", + artifact_service_uri="memory://", + ) + + @pytest.mark.asyncio async def test_run_interactively_whitespace_and_exit( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: """run_interactively should skip blank input, echo once, then exit.""" # make a session that belongs to dummy agent - session_service = cli.InMemorySessionService() + session_service = InMemorySessionService() sess = await session_service.create_session(app_name="dummy", user_id="u") - artifact_service = cli.InMemoryArtifactService() - credential_service = cli.InMemoryCredentialService() + artifact_service = InMemoryArtifactService() + credential_service = InMemoryCredentialService() root_agent = BaseAgent(name="root") # fake user input: blank -> 'hello' -> 'exit' diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py new file mode 100644 index 0000000000..5ff92a076b --- /dev/null +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -0,0 +1,162 @@ +# Copyright 2025 Google LLC +# +# Licensed 
under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for service factory helpers.""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock + +import google.adk.cli.utils.service_factory as service_factory +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.sessions.database_session_service import DatabaseSessionService +from google.adk.sessions.in_memory_session_service import InMemorySessionService +import pytest + + +def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_session_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri="sqlite:///test.db", + ) + + assert result is expected + registry.create_session_service.assert_called_once_with( + "sqlite:///test.db", + agents_dir=str(tmp_path), + per_agent=False, + ) + + +def test_create_session_service_per_agent_uri(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_session_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri="memory://", + per_agent=True, + ) + + assert result is expected + registry.create_session_service.assert_called_once_with( + "memory://", agents_dir=str(tmp_path), per_agent=True + ) + + +@pytest.mark.parametrize("per_agent", [True, False]) +def test_create_session_service_defaults_to_memory( + tmp_path: Path, per_agent: bool +): + service = service_factory.create_session_service_from_options( + base_dir=tmp_path, + per_agent=per_agent, + ) + + assert isinstance(service, InMemorySessionService) + + +def test_create_session_service_fallbacks_to_database( + tmp_path: Path, monkeypatch +): + registry = Mock() + registry.create_session_service.return_value = None + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + service = service_factory.create_session_service_from_options( + base_dir=tmp_path, + session_service_uri="sqlite+aiosqlite:///:memory:", + session_db_kwargs={"echo": True}, + ) + + assert isinstance(service, DatabaseSessionService) + assert service.db_engine.url.drivername == "sqlite+aiosqlite" + assert service.db_engine.echo is True + registry.create_session_service.assert_called_once_with( + "sqlite+aiosqlite:///:memory:", + agents_dir=str(tmp_path), + per_agent=False, + echo=True, + ) + + +@pytest.mark.parametrize("per_agent", [True, False]) +def test_create_artifact_service_uses_registry( + tmp_path: Path, monkeypatch, per_agent: bool +): + registry = Mock() + expected = object() + registry.create_artifact_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = 
service_factory.create_artifact_service_from_options( + base_dir=tmp_path, + artifact_service_uri="gs://bucket/path", + per_agent=per_agent, + ) + + assert result is expected + registry.create_artifact_service.assert_called_once_with( + "gs://bucket/path", + agents_dir=str(tmp_path), + per_agent=per_agent, + ) + + +def test_create_memory_service_uses_registry(tmp_path: Path, monkeypatch): + registry = Mock() + expected = object() + registry.create_memory_service.return_value = expected + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + result = service_factory.create_memory_service_from_options( + base_dir=tmp_path, + memory_service_uri="rag://my-corpus", + ) + + assert result is expected + registry.create_memory_service.assert_called_once_with( + "rag://my-corpus", + agents_dir=str(tmp_path), + ) + + +def test_create_memory_service_defaults_to_in_memory(tmp_path: Path): + service = service_factory.create_memory_service_from_options( + base_dir=tmp_path + ) + + assert isinstance(service, InMemoryMemoryService) + + +def test_create_memory_service_raises_on_unknown_scheme( + tmp_path: Path, monkeypatch +): + registry = Mock() + registry.create_memory_service.return_value = None + monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) + + with pytest.raises(ValueError): + service_factory.create_memory_service_from_options( + base_dir=tmp_path, + memory_service_uri="unknown://foo", + ) From f283027e9215fc64e4293074dd97584aef3b8c0b Mon Sep 17 00:00:00 2001 From: George Weale Date: Tue, 25 Nov 2025 11:12:34 -0800 Subject: [PATCH 39/63] feat: expose service URI flags Adds the shared adk_services_options decorator to adk run and other commands so developers can pass session/artifact URIs from the CLI Has new warning for the unsupported memory service on adk run, and removes the legacy --session_db_url/--artifact_storage_uri flags with tests Co-authored-by: George Weale PiperOrigin-RevId: 836743358 --- src/google/adk/cli/cli_tools_click.py | 128 ++++++++++-------- src/google/adk/cli/utils/__init__.py | 2 + .../cli/utils/test_cli_tools_click.py | 87 +++++++++--- 3 files changed, 143 insertions(+), 74 deletions(-) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 529ee7319c..c4a13dd15f 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -24,6 +24,7 @@ import os from pathlib import Path import tempfile +import textwrap from typing import Optional import click @@ -354,7 +355,62 @@ def validate_exclusive(ctx, param, value): return value +def adk_services_options(): + """Decorator to add ADK services options to click commands.""" + + def decorator(func): + @click.option( + "--session_service_uri", + help=textwrap.dedent( + """\ + Optional. The URI of the session service. + - Leave unset to use the in-memory session service (default). + - Use 'agentengine://' to connect to Agent Engine + sessions. can either be the full qualified resource + name 'projects/abc/locations/us-central1/reasoningEngines/123' or + the resource id '123'. + - Use 'memory://' to run with the in-memory session service. + - Use 'sqlite://' to connect to a SQLite DB. + - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls for more details on supported database URIs.""" + ), + ) + @click.option( + "--artifact_service_uri", + type=str, + help=textwrap.dedent( + """\ + Optional. The URI of the artifact service. 
+ - Leave unset to store artifacts under '.adk/artifacts' locally. + - Use 'gs://' to connect to the GCS artifact service. + - Use 'memory://' to force the in-memory artifact service. + - Use 'file://' to store artifacts in a custom local directory.""" + ), + default=None, + ) + @click.option( + "--memory_service_uri", + type=str, + help=textwrap.dedent("""\ + Optional. The URI of the memory service. + - Use 'rag://' to connect to Vertex AI Rag Memory Service. + - Use 'agentengine://' to connect to Agent Engine + sessions. can either be the full qualified resource + name 'projects/abc/locations/us-central1/reasoningEngines/123' or + the resource id '123'. + - Use 'memory://' to force the in-memory memory service."""), + default=None, + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + return decorator + + @main.command("run", cls=HelpfulCommand) +@adk_services_options() @click.option( "--save_session", type=bool, @@ -409,6 +465,9 @@ def cli_run( session_id: Optional[str], replay: Optional[str], resume: Optional[str], + session_service_uri: Optional[str] = None, + artifact_service_uri: Optional[str] = None, + memory_service_uri: Optional[str] = None, ): """Runs an interactive CLI for a certain agent. @@ -420,6 +479,14 @@ def cli_run( """ logs.log_to_tmp_folder() + # Validation warning for memory_service_uri (not supported for adk run) + if memory_service_uri: + click.secho( + "WARNING: --memory_service_uri is not supported for adk run.", + fg="yellow", + err=True, + ) + agent_parent_folder = os.path.dirname(agent) agent_folder_name = os.path.basename(agent) @@ -431,6 +498,8 @@ def cli_run( saved_session_file=resume, save_session=save_session, session_id=session_id, + session_service_uri=session_service_uri, + artifact_service_uri=artifact_service_uri, ) ) @@ -865,55 +934,6 @@ def wrapper(*args, **kwargs): return decorator -def adk_services_options(): - """Decorator to add ADK services options to click commands.""" - - def decorator(func): - @click.option( - "--session_service_uri", - help=( - """Optional. The URI of the session service. - - Use 'agentengine://' to connect to Agent Engine - sessions. can either be the full qualified resource - name 'projects/abc/locations/us-central1/reasoningEngines/123' or - the resource id '123'. - - Use 'sqlite://' to connect to an aio-sqlite - based session service, which is good for local development. - - Use 'postgresql://:@:/' - to connect to a PostgreSQL DB. - - See https://docs.sqlalchemy.org/en/20/core/engines.html#backend-specific-urls - for more details on other database URIs supported by SQLAlchemy.""" - ), - ) - @click.option( - "--artifact_service_uri", - type=str, - help=( - "Optional. The URI of the artifact service," - " supported URIs: gs:// for GCS artifact service." - ), - default=None, - ) - @click.option( - "--memory_service_uri", - type=str, - help=("""Optional. The URI of the memory service. - - Use 'rag://' to connect to Vertex AI Rag Memory Service. - - Use 'agentengine://' to connect to Agent Engine - sessions. 
can either be the full qualified resource - name 'projects/abc/locations/us-central1/reasoningEngines/123' or - the resource id '123'."""), - default=None, - ) - @functools.wraps(func) - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - - return wrapper - - return decorator - - def deprecated_adk_services_options(): """Deprecated ADK services options.""" @@ -921,7 +941,7 @@ def warn(alternative_param, ctx, param, value): if value: click.echo( click.style( - f"WARNING: Deprecated option {param.name} is used. Please use" + f"WARNING: Deprecated option --{param.name} is used. Please use" f" {alternative_param} instead.", fg="yellow", ), @@ -1116,6 +1136,8 @@ def cli_web( adk web --session_service_uri=[uri] --port=[port] path/to/agents_dir """ + session_service_uri = session_service_uri or session_db_url + artifact_service_uri = artifact_service_uri or artifact_storage_uri logs.setup_adk_logger(getattr(logging, log_level.upper())) @asynccontextmanager @@ -1140,8 +1162,6 @@ async def _lifespan(app: FastAPI): fg="green", ) - session_service_uri = session_service_uri or session_db_url - artifact_service_uri = artifact_service_uri or artifact_storage_uri app = get_fast_api_app( agents_dir=agents_dir, session_service_uri=session_service_uri, @@ -1215,10 +1235,10 @@ def cli_api_server( adk api_server --session_service_uri=[uri] --port=[port] path/to/agents_dir """ - logs.setup_adk_logger(getattr(logging, log_level.upper())) - session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri + logs.setup_adk_logger(getattr(logging, log_level.upper())) + config = uvicorn.Config( get_fast_api_app( agents_dir=agents_dir, diff --git a/src/google/adk/cli/utils/__init__.py b/src/google/adk/cli/utils/__init__.py index 8aa11b252b..1800f5d04c 100644 --- a/src/google/adk/cli/utils/__init__.py +++ b/src/google/adk/cli/utils/__init__.py @@ -18,8 +18,10 @@ from ...agents.base_agent import BaseAgent from ...agents.llm_agent import LlmAgent +from .dot_adk_folder import DotAdkFolder from .state import create_empty_state __all__ = [ 'create_empty_state', + 'DotAdkFolder', ] diff --git a/tests/unittests/cli/utils/test_cli_tools_click.py b/tests/unittests/cli/utils/test_cli_tools_click.py index be9015ca87..95b561e57b 100644 --- a/tests/unittests/cli/utils/test_cli_tools_click.py +++ b/tests/unittests/cli/utils/test_cli_tools_click.py @@ -76,8 +76,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> None: # noqa: D401 # Fixtures @pytest.fixture(autouse=True) -def _mute_click(monkeypatch: pytest.MonkeyPatch) -> None: +def _mute_click(request, monkeypatch: pytest.MonkeyPatch) -> None: """Suppress click output during tests.""" + # Allow tests to opt-out of muting by using the 'unmute_click' marker + if "unmute_click" in request.keywords: + return monkeypatch.setattr(click, "echo", lambda *a, **k: None) # Keep secho for error messages # monkeypatch.setattr(click, "secho", lambda *a, **k: None) @@ -121,32 +124,70 @@ def test_cli_create_cmd_invokes_run_cmd( cli_tools_click.main, ["create", "--model", "gemini", "--api_key", "key123", str(app_dir)], ) - assert result.exit_code == 0 + assert result.exit_code == 0, (result.output, repr(result.exception)) assert rec.calls, "cli_create.run_cmd must be called" # cli run -@pytest.mark.asyncio -async def test_cli_run_invokes_run_cli( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch +@pytest.mark.parametrize( + "cli_args,expected_session_uri,expected_artifact_uri", + [ + pytest.param( + [ + 
"--session_service_uri", + "memory://", + "--artifact_service_uri", + "memory://", + ], + "memory://", + "memory://", + id="memory_scheme_uris", + ), + pytest.param( + [], + None, + None, + id="default_uris_none", + ), + ], +) +def test_cli_run_service_uris( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + cli_args: list, + expected_session_uri: str, + expected_artifact_uri: str, ) -> None: - """`adk run` should call run_cli via asyncio.run with correct parameters.""" - rec = _Recorder() - monkeypatch.setattr(cli_tools_click, "run_cli", lambda **kwargs: rec(kwargs)) - monkeypatch.setattr( - cli_tools_click.asyncio, "run", lambda coro: coro - ) # pass-through - - # create dummy agent directory + """`adk run` should forward service URIs correctly to run_cli.""" agent_dir = tmp_path / "agent" agent_dir.mkdir() (agent_dir / "__init__.py").touch() (agent_dir / "agent.py").touch() + # Capture the coroutine's locals before closing it + captured_locals = [] + + def capture_asyncio_run(coro): + # Extract the locals before closing the coroutine + if coro.cr_frame is not None: + captured_locals.append(dict(coro.cr_frame.f_locals)) + coro.close() # Properly close the coroutine to avoid warnings + + monkeypatch.setattr(cli_tools_click.asyncio, "run", capture_asyncio_run) + runner = CliRunner() - result = runner.invoke(cli_tools_click.main, ["run", str(agent_dir)]) - assert result.exit_code == 0 - assert rec.calls and rec.calls[0][0][0]["agent_folder_name"] == "agent" + result = runner.invoke( + cli_tools_click.main, + ["run", *cli_args, str(agent_dir)], + ) + assert result.exit_code == 0, (result.output, repr(result.exception)) + assert len(captured_locals) == 1, "Expected asyncio.run to be called once" + + # Verify the kwargs passed to run_cli + coro_locals = captured_locals[0] + assert coro_locals.get("session_service_uri") == expected_session_uri + assert coro_locals.get("artifact_service_uri") == expected_artifact_uri + assert coro_locals["agent_folder_name"] == "agent" # cli deploy cloud_run @@ -520,10 +561,13 @@ def test_cli_web_passes_service_uris( assert called_kwargs.get("memory_service_uri") == "rag://mycorpus" -def test_cli_web_passes_deprecated_uris( - tmp_path: Path, monkeypatch: pytest.MonkeyPatch, _patch_uvicorn: _Recorder +@pytest.mark.unmute_click +def test_cli_web_warns_and_maps_deprecated_uris( + tmp_path: Path, + _patch_uvicorn: _Recorder, + monkeypatch: pytest.MonkeyPatch, ) -> None: - """`adk web` should use deprecated URIs if new ones are not provided.""" + """`adk web` should accept deprecated URI flags with warnings.""" agents_dir = tmp_path / "agents" agents_dir.mkdir() @@ -542,11 +586,14 @@ def test_cli_web_passes_deprecated_uris( "gs://deprecated", ], ) + assert result.exit_code == 0 - assert mock_get_app.calls called_kwargs = mock_get_app.calls[0][1] assert called_kwargs.get("session_service_uri") == "sqlite:///deprecated.db" assert called_kwargs.get("artifact_service_uri") == "gs://deprecated" + # Check output for deprecation warnings (CliRunner captures both stdout and stderr) + assert "--session_db_url" in result.output + assert "--artifact_storage_uri" in result.output def test_cli_eval_with_eval_set_file_path( From ec4ccd718feeadeb6b2b59fcc0e9ff29a4fd0bac Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Wed, 26 Nov 2025 10:01:22 -0800 Subject: [PATCH 40/63] feat: Create APIRegistryToolset to add tools from Cloud API registry to agent This calls the cloudapiregistry.googleapis.com API to get MCP tools from the project's registry, and adds them to ADK. 
Co-authored-by: Kathy Wu PiperOrigin-RevId: 837166909 --- .../samples/api_registry_agent/README.md | 21 ++ .../samples/api_registry_agent/__init__.py | 15 ++ .../samples/api_registry_agent/agent.py | 39 ++++ src/google/adk/tools/__init__.py | 2 + src/google/adk/tools/api_registry.py | 124 +++++++++++ tests/unittests/tools/test_api_registry.py | 205 ++++++++++++++++++ 6 files changed, 406 insertions(+) create mode 100644 contributing/samples/api_registry_agent/README.md create mode 100644 contributing/samples/api_registry_agent/__init__.py create mode 100644 contributing/samples/api_registry_agent/agent.py create mode 100644 src/google/adk/tools/api_registry.py create mode 100644 tests/unittests/tools/test_api_registry.py diff --git a/contributing/samples/api_registry_agent/README.md b/contributing/samples/api_registry_agent/README.md new file mode 100644 index 0000000000..78b3c22382 --- /dev/null +++ b/contributing/samples/api_registry_agent/README.md @@ -0,0 +1,21 @@ +# BigQuery API Registry Agent + +This agent demonstrates how to use `ApiRegistry` to discover and interact with Google Cloud services like BigQuery via tools exposed by an MCP server registered in an API Registry. + +## Prerequisites + +- A Google Cloud project with the API Registry API enabled. +- An MCP server exposing BigQuery tools registered in API Registry. + +## Configuration & Running + +1. **Configure:** Edit `agent.py` and replace `your-google-cloud-project-id` and `your-mcp-server-name` with your Google Cloud Project ID and the name of your registered MCP server. +2. **Run in CLI:** + ```bash + adk run contributing/samples/api_registry_agent -- --log-level DEBUG + ``` +3. **Run in Web UI:** + ```bash + adk web contributing/samples/ + ``` + Navigate to `http://127.0.0.1:8080` and select the `api_registry_agent` agent. diff --git a/contributing/samples/api_registry_agent/__init__.py b/contributing/samples/api_registry_agent/__init__.py new file mode 100644 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/api_registry_agent/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/api_registry_agent/agent.py b/contributing/samples/api_registry_agent/agent.py new file mode 100644 index 0000000000..6504822092 --- /dev/null +++ b/contributing/samples/api_registry_agent/agent.py @@ -0,0 +1,39 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.tools.api_registry import ApiRegistry + +# TODO: Fill in with your GCloud project id and MCP server name +PROJECT_ID = "your-google-cloud-project-id" +MCP_SERVER_NAME = "your-mcp-server-name" + +# Header required for BigQuery MCP server +header_provider = lambda context: { + "x-goog-user-project": PROJECT_ID, +} +api_registry = ApiRegistry(PROJECT_ID, header_provider=header_provider) +registry_tools = api_registry.get_toolset( + mcp_server_name=MCP_SERVER_NAME, +) +root_agent = LlmAgent( + model="gemini-2.0-flash", + name="bigquery_assistant", + instruction=""" +Help user access their BigQuery data via API Registry tools. + """, + tools=[registry_tools], +) diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py index d359abb728..32264adcbd 100644 --- a/src/google/adk/tools/__init__.py +++ b/src/google/adk/tools/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from ..auth.auth_tool import AuthToolArguments from .agent_tool import AgentTool + from .api_registry import ApiRegistry from .apihub_tool.apihub_toolset import APIHubToolset from .base_tool import BaseTool from .discovery_engine_search_tool import DiscoveryEngineSearchTool @@ -84,6 +85,7 @@ 'VertexAiSearchTool': ('.vertex_ai_search_tool', 'VertexAiSearchTool'), 'MCPToolset': ('.mcp_tool.mcp_toolset', 'MCPToolset'), 'McpToolset': ('.mcp_tool.mcp_toolset', 'McpToolset'), + 'ApiRegistry': ('.api_registry', 'ApiRegistry'), } __all__ = list(_LAZY_MAPPING.keys()) diff --git a/src/google/adk/tools/api_registry.py b/src/google/adk/tools/api_registry.py new file mode 100644 index 0000000000..941c6f0d5c --- /dev/null +++ b/src/google/adk/tools/api_registry.py @@ -0,0 +1,124 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import sys +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + +import google.auth +import google.auth.transport.requests +import httpx + +from .base_toolset import ToolPredicate +from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from .mcp_tool.mcp_toolset import McpToolset + +# TODO(wukathy): Update to prod URL once it is available. +API_REGISTRY_URL = "https://staging-cloudapiregistry.sandbox.googleapis.com" + + +class ApiRegistry: + """Registry that provides McpToolsets for MCP servers registered in API Registry.""" + + def __init__( + self, + api_registry_project_id: str, + location: str = "global", + header_provider: Optional[ + Callable[[ReadonlyContext], Dict[str, str]] + ] = None, + ): + """Initialize the API Registry. + + Args: + api_registry_project_id: The project ID for the Google Cloud API Registry. + location: The location of the API Registry resources. + header_provider: Optional function to provide additional headers for MCP + server calls. 
+ """ + self.api_registry_project_id = api_registry_project_id + self.location = location + self._credentials, _ = google.auth.default() + self._mcp_servers: Dict[str, Dict[str, Any]] = {} + self._header_provider = header_provider + + url = f"{API_REGISTRY_URL}/v1beta/projects/{self.api_registry_project_id}/locations/{self.location}/mcpServers" + try: + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + "Content-Type": "application/json", + } + with httpx.Client() as client: + response = client.get(url, headers=headers) + response.raise_for_status() + mcp_servers_list = response.json().get("mcpServers", []) + for server in mcp_servers_list: + server_name = server.get("name", "") + if server_name: + self._mcp_servers[server_name] = server + except (httpx.HTTPError, ValueError) as e: + # Handle error in fetching or parsing tool definitions + raise RuntimeError( + f"Error fetching MCP servers from API Registry: {e}" + ) from e + + def get_toolset( + self, + mcp_server_name: str, + tool_filter: Optional[Union[ToolPredicate, List[str]]] = None, + tool_name_prefix: Optional[str] = None, + ) -> McpToolset: + """Return the MCP Toolset based on the params. + + Args: + mcp_server_name: Filter to select the MCP server name to get tools + from. + tool_filter: Optional filter to select specific tools. Can be a list of + tool names or a ToolPredicate function. + tool_name_prefix: Optional prefix to prepend to the names of the tools + returned by the toolset. + + Returns: + McpToolset: A toolset for the MCP server specified. + """ + server = self._mcp_servers.get(mcp_server_name) + if not server: + raise ValueError( + f"MCP server {mcp_server_name} not found in API Registry." + ) + if not server.get("urls"): + raise ValueError(f"MCP server {mcp_server_name} has no URLs.") + + mcp_server_url = server["urls"][0] + request = google.auth.transport.requests.Request() + self._credentials.refresh(request) + headers = { + "Authorization": f"Bearer {self._credentials.token}", + } + return McpToolset( + connection_params=StreamableHTTPConnectionParams( + url="https://" + mcp_server_url, + headers=headers, + ), + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + header_provider=self._header_provider, + ) diff --git a/tests/unittests/tools/test_api_registry.py b/tests/unittests/tools/test_api_registry.py new file mode 100644 index 0000000000..d1131eed0b --- /dev/null +++ b/tests/unittests/tools/test_api_registry.py @@ -0,0 +1,205 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys +import unittest +from unittest.mock import MagicMock +from unittest.mock import patch + +from google.adk.tools.api_registry import ApiRegistry +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +import httpx + +MOCK_MCP_SERVERS_LIST = { + "mcpServers": [ + { + "name": "test-mcp-server-1", + "urls": ["mcp.server1.com"], + }, + { + "name": "test-mcp-server-2", + "urls": ["mcp.server2.com"], + }, + { + "name": "test-mcp-server-no-url", + }, + ] +} + + +class TestApiRegistry(unittest.IsolatedAsyncioTestCase): + """Unit tests for ApiRegistry.""" + + def setUp(self): + self.project_id = "test-project" + self.location = "global" + self.mock_credentials = MagicMock() + self.mock_credentials.token = "mock_token" + self.mock_credentials.refresh = MagicMock() + mock_auth_patcher = patch( + "google.auth.default", + return_value=(self.mock_credentials, None), + autospec=True, + ) + mock_auth_patcher.start() + self.addCleanup(mock_auth_patcher.stop) + + @patch("httpx.Client", autospec=True) + def test_init_success(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + self.assertEqual(len(api_registry._mcp_servers), 3) + self.assertIn("test-mcp-server-1", api_registry._mcp_servers) + self.assertIn("test-mcp-server-2", api_registry._mcp_servers) + self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers) + mock_client_instance.get.assert_called_once_with( + f"https://staging-cloudapiregistry.sandbox.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", + headers={ + "Authorization": "Bearer mock_token", + "Content-Type": "application/json", + }, + ) + + @patch("httpx.Client", autospec=True) + def test_init_http_error(self, MockHttpClient): + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.side_effect = httpx.RequestError( + "Connection failed" + ) + + with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"): + ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + @patch("httpx.Client", autospec=True) + def test_init_bad_response(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock( + side_effect=httpx.HTTPStatusError( + "Not Found", request=MagicMock(), response=MagicMock() + ) + ) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + with self.assertRaisesRegex(RuntimeError, "Error fetching MCP servers"): + ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + mock_response.raise_for_status.assert_called_once() + + @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch("httpx.Client", autospec=True) + async def test_get_toolset_success(self, MockHttpClient, MockMcpToolset): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = 
MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + toolset = api_registry.get_toolset("test-mcp-server-1") + + MockMcpToolset.assert_called_once_with( + connection_params=StreamableHTTPConnectionParams( + url="https://mcp.server1.com", + headers={"Authorization": "Bearer mock_token"}, + ), + tool_filter=None, + tool_name_prefix=None, + header_provider=None, + ) + self.assertEqual(toolset, MockMcpToolset.return_value) + + @patch("google.adk.tools.api_registry.McpToolset", autospec=True) + @patch("httpx.Client", autospec=True) + async def test_get_toolset_with_filter_and_prefix( + self, MockHttpClient, MockMcpToolset + ): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + tool_filter = ["tool1"] + tool_name_prefix = "prefix_" + toolset = api_registry.get_toolset( + "test-mcp-server-1", + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + ) + + MockMcpToolset.assert_called_once_with( + connection_params=StreamableHTTPConnectionParams( + url="https://mcp.server1.com", + headers={"Authorization": "Bearer mock_token"}, + ), + tool_filter=tool_filter, + tool_name_prefix=tool_name_prefix, + header_provider=None, + ) + self.assertEqual(toolset, MockMcpToolset.return_value) + + @patch("httpx.Client", autospec=True) + async def test_get_toolset_server_not_found(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + with self.assertRaisesRegex(ValueError, "not found in API Registry"): + api_registry.get_toolset("non-existent-server") + + @patch("httpx.Client", autospec=True) + async def test_get_toolset_server_no_url(self, MockHttpClient): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json = MagicMock(return_value=MOCK_MCP_SERVERS_LIST) + mock_client_instance = MockHttpClient.return_value + mock_client_instance.__enter__.return_value = mock_client_instance + mock_client_instance.get.return_value = mock_response + + api_registry = ApiRegistry( + api_registry_project_id=self.project_id, location=self.location + ) + + with self.assertRaisesRegex(ValueError, "has no URLs"): + api_registry.get_toolset("test-mcp-server-no-url") From 73e5687b9a2014586c0a5d281c6daed2d4e1186f Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 26 Nov 2025 10:07:05 -0800 Subject: [PATCH 41/63] fix: Remove 'per_agent' from kwargs when using remote session service URIs Co-authored-by: George Weale PiperOrigin-RevId: 837169299 --- src/google/adk/cli/fast_api.py | 2 -- src/google/adk/cli/service_registry.py | 8 ++--- src/google/adk/cli/utils/local_storage.py | 11 ++----- 
src/google/adk/cli/utils/service_factory.py | 19 +---------- tests/unittests/cli/utils/test_cli.py | 19 ++--------- .../cli/utils/test_service_factory.py | 33 ++----------------- 6 files changed, 10 insertions(+), 82 deletions(-) diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index 86c7ca55c6..c095b03a30 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -106,7 +106,6 @@ def get_fast_api_app( base_dir=agents_dir, session_service_uri=session_service_uri, session_db_kwargs=session_db_kwargs, - per_agent=True, # Multi-agent mode ) # Build the Artifact service @@ -114,7 +113,6 @@ def get_fast_api_app( artifact_service = create_artifact_service_from_options( base_dir=agents_dir, artifact_service_uri=artifact_service_uri, - per_agent=True, # Multi-agent mode ) except ValueError as exc: raise click.ClickException(str(exc)) from exc diff --git a/src/google/adk/cli/service_registry.py b/src/google/adk/cli/service_registry.py index 521a3f9f7d..3e7921e075 100644 --- a/src/google/adk/cli/service_registry.py +++ b/src/google/adk/cli/service_registry.py @@ -271,18 +271,14 @@ def gcs_artifact_factory(uri: str, **kwargs): kwargs_copy = kwargs.copy() kwargs_copy.pop("agents_dir", None) + kwargs_copy.pop("per_agent", None) parsed_uri = urlparse(uri) bucket_name = parsed_uri.netloc return GcsArtifactService(bucket_name=bucket_name, **kwargs_copy) - def file_artifact_factory(uri: str, **kwargs): + def file_artifact_factory(uri: str, **_): from ..artifacts.file_artifact_service import FileArtifactService - per_agent = kwargs.get("per_agent", False) - if per_agent: - raise ValueError( - "file:// artifact URIs are not supported in multi-agent mode." - ) parsed_uri = urlparse(uri) if parsed_uri.netloc not in ("", "localhost"): raise ValueError( diff --git a/src/google/adk/cli/utils/local_storage.py b/src/google/adk/cli/utils/local_storage.py index 9e6b3f3d54..b170d66531 100644 --- a/src/google/adk/cli/utils/local_storage.py +++ b/src/google/adk/cli/utils/local_storage.py @@ -58,13 +58,12 @@ def create_local_database_session_service( def create_local_artifact_service( - *, base_dir: Path | str, per_agent: bool = False + *, base_dir: Path | str ) -> BaseArtifactService: """Creates a file-backed artifact service rooted in `.adk/artifacts`. Args: base_dir: Directory whose `.adk` folder will store artifacts. - per_agent: Indicates whether the service is being used in multi-agent mode. Returns: A `FileArtifactService` scoped to the derived root directory. 
@@ -72,13 +71,7 @@ def create_local_artifact_service( manager = DotAdkFolder(base_dir) artifact_root = manager.artifacts_dir artifact_root.mkdir(parents=True, exist_ok=True) - if per_agent: - logger.info( - "Using shared file artifact service rooted at %s for multi-agent mode", - artifact_root, - ) - else: - logger.info("Using file artifact service at %s", artifact_root) + logger.info("Using file artifact service at %s", artifact_root) return FileArtifactService(root_dir=artifact_root) diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py index 50064f4b8f..fc2a642c4f 100644 --- a/src/google/adk/cli/utils/service_factory.py +++ b/src/google/adk/cli/utils/service_factory.py @@ -32,7 +32,6 @@ def create_session_service_from_options( base_dir: Path | str, session_service_uri: Optional[str] = None, session_db_kwargs: Optional[dict[str, Any]] = None, - per_agent: bool = False, ) -> BaseSessionService: """Creates a session service based on CLI/web options.""" base_path = Path(base_dir) @@ -40,17 +39,11 @@ def create_session_service_from_options( kwargs: dict[str, Any] = { "agents_dir": str(base_path), - "per_agent": per_agent, } if session_db_kwargs: kwargs.update(session_db_kwargs) if session_service_uri: - if per_agent: - logger.warning( - "per_agent is not supported with remote session service URIs," - " ignoring" - ) logger.info("Using session service URI: %s", session_service_uri) service = registry.create_session_service(session_service_uri, **kwargs) if service is not None: @@ -63,7 +56,6 @@ def create_session_service_from_options( fallback_kwargs = dict(kwargs) fallback_kwargs.pop("agents_dir", None) - fallback_kwargs.pop("per_agent", None) logger.info( "Falling back to DatabaseSessionService for URI: %s", session_service_uri, @@ -105,23 +97,16 @@ def create_artifact_service_from_options( *, base_dir: Path | str, artifact_service_uri: Optional[str] = None, - per_agent: bool = False, ) -> BaseArtifactService: """Creates an artifact service based on CLI/web options.""" base_path = Path(base_dir) registry = get_service_registry() if artifact_service_uri: - if per_agent: - logger.warning( - "per_agent is not supported with remote artifact service URIs," - " ignoring" - ) logger.info("Using artifact service URI: %s", artifact_service_uri) service = registry.create_artifact_service( artifact_service_uri, agents_dir=str(base_path), - per_agent=per_agent, ) if service is None: logger.warning( @@ -133,6 +118,4 @@ def create_artifact_service_from_options( return InMemoryArtifactService() return service - if per_agent: - logger.info("Using shared file artifact service rooted at %s", base_dir) - return create_local_artifact_service(base_dir=base_path, per_agent=per_agent) + return create_local_artifact_service(base_dir=base_path) diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 33ddbf495c..73ae89a986 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -300,13 +300,11 @@ def test_create_artifact_service_defaults_to_file(tmp_path: Path) -> None: assert expected_root.exists() -def test_create_artifact_service_per_agent_uses_shared_root( +def test_create_artifact_service_uses_shared_root( tmp_path: Path, ) -> None: - """Multi-agent mode should still use a single file artifact service.""" - service = create_artifact_service_from_options( - base_dir=tmp_path, per_agent=True - ) + """Artifact service should use a single file artifact service.""" + service = 
create_artifact_service_from_options(base_dir=tmp_path) assert isinstance(service, FileArtifactService) expected_root = Path(tmp_path) / ".adk" / "artifacts" assert service.root_dir == expected_root @@ -332,17 +330,6 @@ def test_create_artifact_service_accepts_file_uri(tmp_path: Path) -> None: assert custom_root.exists() -def test_create_artifact_service_file_uri_rejects_per_agent(tmp_path: Path): - """file:// URIs are incompatible with per-agent mode.""" - custom_root = tmp_path / "custom" - with pytest.raises(ValueError, match="multi-agent"): - create_artifact_service_from_options( - base_dir=tmp_path, - artifact_service_uri=custom_root.as_uri(), - per_agent=True, - ) - - @pytest.mark.asyncio async def test_run_cli_accepts_memory_scheme( fake_agent, tmp_path: Path diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py index 5ff92a076b..207c96642a 100644 --- a/tests/unittests/cli/utils/test_service_factory.py +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -41,35 +41,12 @@ def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): registry.create_session_service.assert_called_once_with( "sqlite:///test.db", agents_dir=str(tmp_path), - per_agent=False, ) -def test_create_session_service_per_agent_uri(tmp_path: Path, monkeypatch): - registry = Mock() - expected = object() - registry.create_session_service.return_value = expected - monkeypatch.setattr(service_factory, "get_service_registry", lambda: registry) - - result = service_factory.create_session_service_from_options( - base_dir=tmp_path, - session_service_uri="memory://", - per_agent=True, - ) - - assert result is expected - registry.create_session_service.assert_called_once_with( - "memory://", agents_dir=str(tmp_path), per_agent=True - ) - - -@pytest.mark.parametrize("per_agent", [True, False]) -def test_create_session_service_defaults_to_memory( - tmp_path: Path, per_agent: bool -): +def test_create_session_service_defaults_to_memory(tmp_path: Path): service = service_factory.create_session_service_from_options( base_dir=tmp_path, - per_agent=per_agent, ) assert isinstance(service, InMemorySessionService) @@ -94,15 +71,11 @@ def test_create_session_service_fallbacks_to_database( registry.create_session_service.assert_called_once_with( "sqlite+aiosqlite:///:memory:", agents_dir=str(tmp_path), - per_agent=False, echo=True, ) -@pytest.mark.parametrize("per_agent", [True, False]) -def test_create_artifact_service_uses_registry( - tmp_path: Path, monkeypatch, per_agent: bool -): +def test_create_artifact_service_uses_registry(tmp_path: Path, monkeypatch): registry = Mock() expected = object() registry.create_artifact_service.return_value = expected @@ -111,14 +84,12 @@ def test_create_artifact_service_uses_registry( result = service_factory.create_artifact_service_from_options( base_dir=tmp_path, artifact_service_uri="gs://bucket/path", - per_agent=per_agent, ) assert result is expected registry.create_artifact_service.assert_called_once_with( "gs://bucket/path", agents_dir=str(tmp_path), - per_agent=per_agent, ) From 786aaed335e1ce64b7e92dff2f4af8316b2ef593 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Wed, 26 Nov 2025 10:14:25 -0800 Subject: [PATCH 42/63] feat: Support streaming function call arguments in progressive SSE streaming feature Co-authored-by: Xuan Yang PiperOrigin-RevId: 837172244 --- .../hello_world_stream_fc_args/__init__.py | 15 ++ .../hello_world_stream_fc_args/agent.py | 55 ++++ pyproject.toml | 2 +- 
src/google/adk/utils/streaming_utils.py | 181 ++++++++++++- .../test_progressive_sse_streaming.py | 242 ++++++++++++++++++ 5 files changed, 492 insertions(+), 3 deletions(-) create mode 100755 contributing/samples/hello_world_stream_fc_args/__init__.py create mode 100755 contributing/samples/hello_world_stream_fc_args/agent.py diff --git a/contributing/samples/hello_world_stream_fc_args/__init__.py b/contributing/samples/hello_world_stream_fc_args/__init__.py new file mode 100755 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/hello_world_stream_fc_args/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/hello_world_stream_fc_args/agent.py b/contributing/samples/hello_world_stream_fc_args/agent.py new file mode 100755 index 0000000000..f613842171 --- /dev/null +++ b/contributing/samples/hello_world_stream_fc_args/agent.py @@ -0,0 +1,55 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk import Agent +from google.genai import types + + +def concat_number_and_string(num: int, s: str) -> str: + """Concatenate a number and a string. + + Args: + num: The number to concatenate. + s: The string to concatenate. + + Returns: + The concatenated string. + """ + return str(num) + ': ' + s + + +root_agent = Agent( + model='gemini-3-pro-preview', + name='hello_world_stream_fc_args', + description='Demo agent showcasing streaming function call arguments.', + instruction=""" + You are a helpful assistant. + You can use the `concat_number_and_string` tool to concatenate a number and a string. + You should always call the concat_number_and_string tool to concatenate a number and a string. + You should never concatenate on your own. 
+ """, + tools=[ + concat_number_and_string, + ], + generate_content_config=types.GenerateContentConfig( + automatic_function_calling=types.AutomaticFunctionCallingConfig( + disable=True, + ), + tool_config=types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + stream_function_call_arguments=True, + ), + ), + ), +) diff --git a/pyproject.toml b/pyproject.toml index 5c0515d6b5..06ddb04ef2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database "google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription "google-cloud-storage>=2.18.0, <4.0.0", # For GCS Artifact service - "google-genai>=1.45.0, <2.0.0", # Google GenAI SDK + "google-genai>=1.51.0, <2.0.0", # Google GenAI SDK "graphviz>=0.20.2, <1.0.0", # Graphviz for graph rendering "jsonschema>=4.23.0, <5.0.0", # Agent Builder config validation "mcp>=1.10.0, <2.0.0", # For MCP Toolset diff --git a/src/google/adk/utils/streaming_utils.py b/src/google/adk/utils/streaming_utils.py index eb75365467..eae80aa7cc 100644 --- a/src/google/adk/utils/streaming_utils.py +++ b/src/google/adk/utils/streaming_utils.py @@ -14,6 +14,7 @@ from __future__ import annotations +from typing import Any from typing import AsyncGenerator from typing import Optional @@ -43,6 +44,12 @@ def __init__(self): self._current_text_is_thought: Optional[bool] = None self._finish_reason: Optional[types.FinishReason] = None + # For streaming function call arguments + self._current_fc_name: Optional[str] = None + self._current_fc_args: dict[str, Any] = {} + self._current_fc_id: Optional[str] = None + self._current_thought_signature: Optional[str] = None + def _flush_text_buffer_to_sequence(self): """Flush current text buffer to parts sequence. @@ -61,6 +68,171 @@ def _flush_text_buffer_to_sequence(self): self._current_text_buffer = '' self._current_text_is_thought = None + def _get_value_from_partial_arg( + self, partial_arg: types.PartialArg, json_path: str + ): + """Extract value from a partial argument. + + Args: + partial_arg: The partial argument object + json_path: JSONPath for this argument + + Returns: + Tuple of (value, has_value) where has_value indicates if a value exists + """ + value = None + has_value = False + + if partial_arg.string_value is not None: + # For streaming strings, append chunks to existing value + string_chunk = partial_arg.string_value + has_value = True + + # Get current value for this path (if any) + path_without_prefix = ( + json_path[2:] if json_path.startswith('$.') else json_path + ) + path_parts = path_without_prefix.split('.') + + # Try to get existing value + existing_value = self._current_fc_args + for part in path_parts: + if isinstance(existing_value, dict) and part in existing_value: + existing_value = existing_value[part] + else: + existing_value = None + break + + # Append to existing string or set new value + if isinstance(existing_value, str): + value = existing_value + string_chunk + else: + value = string_chunk + + elif partial_arg.number_value is not None: + value = partial_arg.number_value + has_value = True + elif partial_arg.bool_value is not None: + value = partial_arg.bool_value + has_value = True + elif partial_arg.null_value is not None: + value = None + has_value = True + + return value, has_value + + def _set_value_by_json_path(self, json_path: str, value: Any): + """Set a value in _current_fc_args using JSONPath notation. 
+ + Args: + json_path: JSONPath string like "$.location" or "$.location.latitude" + value: The value to set + """ + # Remove leading "$." from jsonPath + if json_path.startswith('$.'): + path = json_path[2:] + else: + path = json_path + + # Split path into components + path_parts = path.split('.') + + # Navigate to the correct location and set the value + current = self._current_fc_args + for part in path_parts[:-1]: + if part not in current: + current[part] = {} + current = current[part] + + # Set the final value + current[path_parts[-1]] = value + + def _flush_function_call_to_sequence(self): + """Flush current function call to parts sequence. + + This creates a complete FunctionCall part from accumulated partial args. + """ + if self._current_fc_name: + # Create function call part with accumulated args + fc_part = types.Part.from_function_call( + name=self._current_fc_name, + args=self._current_fc_args.copy(), + ) + + # Set the ID if provided (directly on the function_call object) + if self._current_fc_id and fc_part.function_call: + fc_part.function_call.id = self._current_fc_id + + # Set thought_signature if provided (on the Part, not FunctionCall) + if self._current_thought_signature: + fc_part.thought_signature = self._current_thought_signature + + self._parts_sequence.append(fc_part) + + # Reset FC state + self._current_fc_name = None + self._current_fc_args = {} + self._current_fc_id = None + self._current_thought_signature = None + + def _process_streaming_function_call(self, fc: types.FunctionCall): + """Process a streaming function call with partialArgs. + + Args: + fc: The function call object with partial_args + """ + # Save function name if present (first chunk) + if fc.name: + self._current_fc_name = fc.name + if fc.id: + self._current_fc_id = fc.id + + # Process each partial argument + for partial_arg in getattr(fc, 'partial_args', []): + json_path = partial_arg.json_path + if not json_path: + continue + + # Extract value from partial arg + value, has_value = self._get_value_from_partial_arg( + partial_arg, json_path + ) + + # Set the value using JSONPath (only if a value was provided) + if has_value: + self._set_value_by_json_path(json_path, value) + + # Check if function call is complete + fc_will_continue = getattr(fc, 'will_continue', False) + if not fc_will_continue: + # Function call complete, flush it + self._flush_text_buffer_to_sequence() + self._flush_function_call_to_sequence() + + def _process_function_call_part(self, part: types.Part): + """Process a function call part (streaming or non-streaming). 
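For context on the helpers above (`_get_value_from_partial_arg` and `_set_value_by_json_path`): partial arguments are keyed by a `$.`-prefixed JSONPath, string chunks for the same path are concatenated, and other scalars simply overwrite. The following is a standalone sketch of that accumulation behaviour, written for illustration only; it is not the ADK class itself:

```
# Standalone illustration of how partial args accumulate by JSONPath.
from typing import Any


def set_by_json_path(args: dict[str, Any], json_path: str, value: Any) -> None:
  """Sets a value into a nested dict using a "$.a.b" style path."""
  path = json_path[2:] if json_path.startswith("$.") else json_path
  parts = path.split(".")
  current = args
  for part in parts[:-1]:
    current = current.setdefault(part, {})
  current[parts[-1]] = value


args: dict[str, Any] = {}
# String chunks for the same (top-level) path are concatenated across chunks.
for path, chunk in [("$.location", "New "), ("$.location", "York")]:
  existing = args.get(path[2:])
  set_by_json_path(
      args, path, (existing + chunk) if isinstance(existing, str) else chunk
  )
set_by_json_path(args, "$.unit", "celsius")

assert args == {"location": "New York", "unit": "celsius"}
```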
+ + Args: + part: The part containing a function call + """ + fc = part.function_call + + # Check if this is a streaming FC (has partialArgs) + if hasattr(fc, 'partial_args') and fc.partial_args: + # Streaming function call arguments + + # Save thought_signature from the part (first chunk should have it) + if part.thought_signature and not self._current_thought_signature: + self._current_thought_signature = part.thought_signature + self._process_streaming_function_call(fc) + else: + # Non-streaming function call (standard format with args) + # Skip empty function calls (used as streaming end markers) + if fc.name: + # Flush any buffered text first, then add the FC part + self._flush_text_buffer_to_sequence() + self._parts_sequence.append(part) + async def process_response( self, response: types.GenerateContentResponse ) -> AsyncGenerator[LlmResponse, None]: @@ -101,8 +273,12 @@ async def process_response( if not self._current_text_buffer: self._current_text_is_thought = part.thought self._current_text_buffer += part.text + elif part.function_call: + # Process function call (handles both streaming Args and + # non-streaming Args) + self._process_function_call_part(part) else: - # Non-text part (function_call, bytes, etc.) + # Other non-text parts (bytes, etc.) # Flush any buffered text first, then add the non-text part self._flush_text_buffer_to_sequence() self._parts_sequence.append(part) @@ -155,8 +331,9 @@ def close(self) -> Optional[LlmResponse]: if is_feature_enabled(FeatureName.PROGRESSIVE_SSE_STREAMING): # Always generate final aggregated response in progressive mode if self._response and self._response.candidates: - # Flush any remaining text buffer to complete the sequence + # Flush any remaining buffers to complete the sequence self._flush_text_buffer_to_sequence() + self._flush_function_call_to_sequence() # Use the parts sequence which preserves original ordering final_parts = self._parts_sequence diff --git a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py index e64613ff8d..e589d51c7d 100644 --- a/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py +++ b/tests/unittests/flows/llm_flows/test_progressive_sse_streaming.py @@ -397,3 +397,245 @@ def test_progressive_sse_preserves_part_ordering(): # Part 4: Second function call (from chunk8) assert parts[4].function_call.name == "get_weather" assert parts[4].function_call.args["location"] == "New York" + + +def test_progressive_sse_streaming_function_call_arguments(): + """Test streaming function call arguments feature. 
+ + This test simulates the streamFunctionCallArguments feature where a function + call's arguments are streamed incrementally across multiple chunks: + + Chunk 1: FC name + partial location argument ("New ") + Chunk 2: Continue location argument ("York") -> concatenated to "New York" + Chunk 3: Add unit argument ("celsius"), willContinue=False -> FC complete + + Expected result: FunctionCall(name="get_weather", + args={"location": "New York", "unit": + "celsius"}, + id="fc_001") + """ + + aggregator = StreamingResponseAggregator() + + # Chunk 1: FC name + partial location argument + chunk1_fc = types.FunctionCall( + name="get_weather", + id="fc_001", + partial_args=[ + types.PartialArg(json_path="$.location", string_value="New ") + ], + will_continue=True, + ) + chunk1 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk1_fc)] + ) + ) + ] + ) + + # Chunk 2: Continue streaming location argument + chunk2_fc = types.FunctionCall( + partial_args=[ + types.PartialArg(json_path="$.location", string_value="York") + ], + will_continue=True, + ) + chunk2 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk2_fc)] + ) + ) + ] + ) + + # Chunk 3: Add unit argument, FC complete + chunk3_fc = types.FunctionCall( + partial_args=[ + types.PartialArg(json_path="$.unit", string_value="celsius") + ], + will_continue=False, # FC complete + ) + chunk3 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk3_fc)] + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + # Process all chunks through aggregator + processed_chunks = [] + for chunk in [chunk1, chunk2, chunk3]: + + async def process(): + results = [] + async for response in aggregator.process_response(chunk): + results.append(response) + return results + + import asyncio + + chunk_results = asyncio.run(process()) + processed_chunks.extend(chunk_results) + + # Get final aggregated response + final_response = aggregator.close() + + # Verify final aggregated response has complete FC + assert final_response is not None + assert len(final_response.content.parts) == 1 + + fc_part = final_response.content.parts[0] + assert fc_part.function_call is not None + assert fc_part.function_call.name == "get_weather" + assert fc_part.function_call.id == "fc_001" + + # Verify arguments were correctly assembled from streaming chunks + args = fc_part.function_call.args + assert args["location"] == "New York" # "New " + "York" concatenated + assert args["unit"] == "celsius" + + +def test_progressive_sse_preserves_thought_signature(): + """Test that thought_signature is preserved when streaming FC arguments. + + This test verifies that when a streaming function call has a thought_signature + in the Part, it is correctly preserved in the final aggregated FunctionCall. 
+ """ + + aggregator = StreamingResponseAggregator() + + # Create a thought signature (simulating what Gemini returns) + # thought_signature is bytes (base64 encoded) + test_thought_signature = b"test_signature_abc123" + + # Chunk with streaming FC args and thought_signature + chunk_fc = types.FunctionCall( + name="add_5_numbers", + id="fc_003", + partial_args=[ + types.PartialArg(json_path="$.num1", number_value=10), + types.PartialArg(json_path="$.num2", number_value=20), + ], + will_continue=False, + ) + + # Create Part with both function_call AND thought_signature + chunk_part = types.Part( + function_call=chunk_fc, thought_signature=test_thought_signature + ) + + chunk = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content(role="model", parts=[chunk_part]), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + # Process chunk through aggregator + async def process(): + results = [] + async for response in aggregator.process_response(chunk): + results.append(response) + return results + + import asyncio + + asyncio.run(process()) + + # Get final aggregated response + final_response = aggregator.close() + + # Verify thought_signature was preserved in the Part + assert final_response is not None + assert len(final_response.content.parts) == 1 + + fc_part = final_response.content.parts[0] + assert fc_part.function_call is not None + assert fc_part.function_call.name == "add_5_numbers" + + assert fc_part.thought_signature == test_thought_signature + + +def test_progressive_sse_handles_empty_function_call(): + """Test that empty function calls are skipped. + + When using streamFunctionCallArguments, Gemini may send an empty + functionCall: {} as the final chunk to signal streaming completion. + This test verifies that such empty function calls are properly skipped + and don't cause errors. 
+ """ + + aggregator = StreamingResponseAggregator() + + # Chunk 1: Streaming FC with partial args + chunk1_fc = types.FunctionCall( + name="concat_number_and_string", + id="fc_001", + partial_args=[ + types.PartialArg(json_path="$.num", number_value=100), + types.PartialArg(json_path="$.s", string_value="ADK"), + ], + will_continue=False, + ) + chunk1 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk1_fc)] + ) + ) + ] + ) + + # Chunk 2: Empty function call (streaming end marker) + chunk2_fc = types.FunctionCall() # Empty function call + chunk2 = types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role="model", parts=[types.Part(function_call=chunk2_fc)] + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + + # Process all chunks through aggregator + async def process(): + results = [] + for chunk in [chunk1, chunk2]: + async for response in aggregator.process_response(chunk): + results.append(response) + return results + + import asyncio + + asyncio.run(process()) + + # Get final aggregated response + final_response = aggregator.close() + + # Verify final response only has the real FC, not the empty one + assert final_response is not None + assert len(final_response.content.parts) == 1 + + fc_part = final_response.content.parts[0] + assert fc_part.function_call is not None + assert fc_part.function_call.name == "concat_number_and_string" + assert fc_part.function_call.id == "fc_001" + + # Verify arguments + args = fc_part.function_call.args + assert args["num"] == 100 + assert args["s"] == "ADK" From 0094eea3cadf5fe2e960cc558e467dd2131de1b7 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 26 Nov 2025 19:48:34 -0800 Subject: [PATCH 43/63] feat!: Migrate DatabaseSessionService to use JSON serialization schema Also provide a command line tool `adk migrate session` for DB migration Addresses https://github.com/google/adk-python/discussions/3605 Addresses https://github.com/google/adk-python/issues/3681 To verify: ``` # Start one postgres DB docker run --name my-postgres -d -e POSTGRES_DB=agent -e POSTGRES_USER=agent -e POSTGRES_PASSWORD=agent -e PGDATA=/var/lib/postgresql/data/pgdata -v pgvolume:/var/lib/postgresql/data -p 5532:5432 postgres # Connect to an old version of ADK and produce some query data adk web --session_service_uri=postgresql://agent:agent@localhost:5532/agent # Check out to the latest branch and restart ADK web # You should see error log ask you to migrate the DB # Start a new DB docker run --name migration-test-db \ -d \ --rm \ -e POSTGRES_DB=agent \ -e POSTGRES_USER=agent \ -e POSTGRES_PASSWORD=agent -e PGDATA=/var/lib/postgresql/data/pgdata -v migration_test_vol:/var/lib/postgresql/data -p 5533:5432 postgres # DB Migration adk migrate session \ --source_db_url="postgresql://agent:agent@localhost:5532/agent" \ --dest_db_url="postgresql://agent:agent@localhost:5533/agent" # Run ADK web with the new DB adk web --session_service_uri=postgresql+asyncpg://agent:agent@localhost:5533/agent # You should see the data from old DB is migrated ``` Co-authored-by: Shangjie Chen PiperOrigin-RevId: 837341139 --- src/google/adk/cli/cli_tools_click.py | 36 ++ .../adk/sessions/database_session_service.py | 222 +++----- .../adk/sessions/migration/_schema_check.py | 114 ++++ .../migrate_from_sqlalchemy_pickle.py | 492 ++++++++++++++++++ .../migrate_from_sqlalchemy_sqlite.py | 0 .../sessions/migration/migration_runner.py | 128 +++++ 
.../adk/sessions/sqlite_session_service.py | 2 +- .../sessions/migration/test_migrations.py | 106 ++++ .../sessions/test_dynamic_pickle_type.py | 181 ------- 9 files changed, 939 insertions(+), 342 deletions(-) create mode 100644 src/google/adk/sessions/migration/_schema_check.py create mode 100644 src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py rename src/google/adk/sessions/{ => migration}/migrate_from_sqlalchemy_sqlite.py (100%) create mode 100644 src/google/adk/sessions/migration/migration_runner.py create mode 100644 tests/unittests/sessions/migration/test_migrations.py delete mode 100644 tests/unittests/sessions/test_dynamic_pickle_type.py diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index c4a13dd15f..e519427259 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -36,6 +36,7 @@ from . import cli_deploy from .. import version from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE +from ..sessions.migration import migration_runner from .cli import run_cli from .fast_api import get_fast_api_app from .utils import envs @@ -1485,6 +1486,41 @@ def cli_deploy_cloud_run( click.secho(f"Deploy failed: {e}", fg="red", err=True) +@main.group() +def migrate(): + """Migrate ADK database schemas.""" + pass + + +@migrate.command("session", cls=HelpfulCommand) +@click.option( + "--source_db_url", + required=True, + help="SQLAlchemy URL of source database.", +) +@click.option( + "--dest_db_url", + required=True, + help="SQLAlchemy URL of destination database.", +) +@click.option( + "--log_level", + type=LOG_LEVELS, + default="INFO", + help="Optional. Set the logging level", +) +def cli_migrate_session( + *, source_db_url: str, dest_db_url: str, log_level: str +): + """Migrates a session database to the latest schema version.""" + logs.setup_adk_logger(getattr(logging, log_level.upper())) + try: + migration_runner.upgrade(source_db_url, dest_db_url) + click.secho("Migration check and upgrade process finished.", fg="green") + except Exception as e: + click.secho(f"Migration failed: {e}", fg="red", err=True) + + @deploy.command("agent_engine") @click.option( "--api_key", diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index a352918211..1576151f23 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -19,18 +19,16 @@ from datetime import timezone import json import logging -import pickle from typing import Any from typing import Optional import uuid -from google.genai import types -from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func +from sqlalchemy import inspect from sqlalchemy import select from sqlalchemy import Text from sqlalchemy.dialects import mysql @@ -41,14 +39,11 @@ from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.inspection import inspect from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship -from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime -from sqlalchemy.types import PickleType from sqlalchemy.types import String 
from sqlalchemy.types import TypeDecorator from typing_extensions import override @@ -57,10 +52,10 @@ from . import _session_util from ..errors.already_exists_error import AlreadyExistsError from ..events.event import Event -from ..events.event_actions import EventActions from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse +from .migration import _schema_check from .session import Session from .state import State @@ -111,41 +106,22 @@ def load_dialect_impl(self, dialect): return self.impl -class DynamicPickleType(TypeDecorator): - """Represents a type that can be pickled.""" - - impl = PickleType - - def load_dialect_impl(self, dialect): - if dialect.name == "mysql": - return dialect.type_descriptor(mysql.LONGBLOB) - if dialect.name == "spanner+spanner": - from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType - - return dialect.type_descriptor(SpannerPickleType) - return self.impl - - def process_bind_param(self, value, dialect): - """Ensures the pickled value is a bytes object before passing it to the database dialect.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.dumps(value) - return value - - def process_result_value(self, value, dialect): - """Ensures the raw bytes from the database are unpickled back into a Python object.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) - return value - - class Base(DeclarativeBase): """Base class for database tables.""" pass +class StorageMetadata(Base): + """Represents internal metadata stored in the database.""" + + __tablename__ = "adk_internal_metadata" + key: Mapped[str] = mapped_column( + String(DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + + class StorageSession(Base): """Represents a session stored in the database.""" @@ -237,46 +213,10 @@ class StorageEvent(Base): ) invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) - author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) - actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) - long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( - Text, nullable=True - ) - branch: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True - ) timestamp: Mapped[PreciseTimestamp] = mapped_column( PreciseTimestamp, default=func.now() ) - - # === Fields from llm_response.py === - content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) - grounding_metadata: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - custom_metadata: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - usage_metadata: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - citation_metadata: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - - partial: Mapped[bool] = mapped_column(Boolean, nullable=True) - turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) - error_code: Mapped[str] = mapped_column( - String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True - ) - error_message: Mapped[str] = mapped_column(String(1024), nullable=True) - interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) - input_transcription: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) - 
output_transcription: Mapped[dict[str, Any]] = mapped_column( - DynamicJSON, nullable=True - ) + event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON) storage_session: Mapped[StorageSession] = relationship( "StorageSession", @@ -291,102 +231,27 @@ class StorageEvent(Base): ), ) - @property - def long_running_tool_ids(self) -> set[str]: - return ( - set(json.loads(self.long_running_tool_ids_json)) - if self.long_running_tool_ids_json - else set() - ) - - @long_running_tool_ids.setter - def long_running_tool_ids(self, value: set[str]): - if value is None: - self.long_running_tool_ids_json = None - else: - self.long_running_tool_ids_json = json.dumps(list(value)) - @classmethod def from_event(cls, session: Session, event: Event) -> StorageEvent: - storage_event = StorageEvent( + """Creates a StorageEvent from an Event.""" + return StorageEvent( id=event.id, invocation_id=event.invocation_id, - author=event.author, - branch=event.branch, - actions=event.actions, session_id=session.id, app_name=session.app_name, user_id=session.user_id, timestamp=datetime.fromtimestamp(event.timestamp), - long_running_tool_ids=event.long_running_tool_ids, - partial=event.partial, - turn_complete=event.turn_complete, - error_code=event.error_code, - error_message=event.error_message, - interrupted=event.interrupted, + event_data=event.model_dump(exclude_none=True, mode="json"), ) - if event.content: - storage_event.content = event.content.model_dump( - exclude_none=True, mode="json" - ) - if event.grounding_metadata: - storage_event.grounding_metadata = event.grounding_metadata.model_dump( - exclude_none=True, mode="json" - ) - if event.custom_metadata: - storage_event.custom_metadata = event.custom_metadata - if event.usage_metadata: - storage_event.usage_metadata = event.usage_metadata.model_dump( - exclude_none=True, mode="json" - ) - if event.citation_metadata: - storage_event.citation_metadata = event.citation_metadata.model_dump( - exclude_none=True, mode="json" - ) - if event.input_transcription: - storage_event.input_transcription = event.input_transcription.model_dump( - exclude_none=True, mode="json" - ) - if event.output_transcription: - storage_event.output_transcription = ( - event.output_transcription.model_dump(exclude_none=True, mode="json") - ) - return storage_event def to_event(self) -> Event: - return Event( - id=self.id, - invocation_id=self.invocation_id, - author=self.author, - branch=self.branch, - # This is needed as previous ADK version pickled actions might not have - # value defined in the current version of the EventActions model. 
- actions=EventActions().model_copy(update=self.actions.model_dump()), - timestamp=self.timestamp.timestamp(), - long_running_tool_ids=self.long_running_tool_ids, - partial=self.partial, - turn_complete=self.turn_complete, - error_code=self.error_code, - error_message=self.error_message, - interrupted=self.interrupted, - custom_metadata=self.custom_metadata, - content=_session_util.decode_model(self.content, types.Content), - grounding_metadata=_session_util.decode_model( - self.grounding_metadata, types.GroundingMetadata - ), - usage_metadata=_session_util.decode_model( - self.usage_metadata, types.GenerateContentResponseUsageMetadata - ), - citation_metadata=_session_util.decode_model( - self.citation_metadata, types.CitationMetadata - ), - input_transcription=_session_util.decode_model( - self.input_transcription, types.Transcription - ), - output_transcription=_session_util.decode_model( - self.output_transcription, types.Transcription - ), - ) + """Converts the StorageEvent to an Event.""" + return Event.model_validate({ + **self.event_data, + "id": self.id, + "invocation_id": self.invocation_id, + "timestamp": self.timestamp.timestamp(), + }) class StorageAppState(Base): @@ -463,7 +328,6 @@ def __init__(self, db_url: str, **kwargs: Any): logger.info("Local timezone: %s", local_timezone) self.db_engine: AsyncEngine = db_engine - self.metadata: MetaData = MetaData() # DB session factory method self.database_session_factory: async_sessionmaker[ @@ -483,10 +347,46 @@ async def _ensure_tables_created(self): async with self._table_creation_lock: # Double-check after acquiring the lock if not self._tables_created: + # Check schema version BEFORE creating tables. + # This prevents creating metadata table on a v0.1 DB. + async with self.database_session_factory() as sql_session: + version, is_v01 = await sql_session.run_sync( + _schema_check.get_version_and_v01_status_sync + ) + + if is_v01: + raise RuntimeError( + "Database schema appears to be v0.1, but" + f" {_schema_check.CURRENT_SCHEMA_VERSION} is required. Please" + " migrate the database using 'adk migrate session'." + ) + elif version and version < _schema_check.CURRENT_SCHEMA_VERSION: + raise RuntimeError( + f"Database schema version is {version}, but current version is" + f" {_schema_check.CURRENT_SCHEMA_VERSION}. Please migrate" + " the database to the latest version using 'adk migrate" + " session'." + ) + async with self.db_engine.begin() as conn: # Uncomment to recreate DB every time # await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) + + # If we are here, DB is either new or >= current version. + # If new or without metadata row, stamp it as current version. 
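A quick aside on the `StorageEvent.from_event`/`to_event` change above: instead of one column per event field, the new schema stores a single `event_data` JSON blob produced by `Event.model_dump(exclude_none=True, mode="json")` and rebuilds the event with `Event.model_validate`. A minimal round-trip sketch (illustrative only; the field values are made up, not taken from the patch):

```
# Sketch of the JSON round trip the new event_data column relies on.
from google.adk.events.event import Event
from google.genai import types

event = Event(
    invocation_id="inv-1",
    author="user",
    content=types.Content(role="user", parts=[types.Part(text="hello")]),
)

# What gets written to the event_data column ...
event_data = event.model_dump(exclude_none=True, mode="json")

# ... and how it is turned back into an Event when the row is read, with the
# column-backed fields (id, invocation_id, timestamp) layered on top.
restored = Event.model_validate({
    **event_data,
    "id": event.id,
    "invocation_id": event.invocation_id,
    "timestamp": event.timestamp,
})
assert restored.content.parts[0].text == "hello"
```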
+ async with self.database_session_factory() as sql_session: + metadata = await sql_session.get( + StorageMetadata, _schema_check.SCHEMA_VERSION_KEY + ) + if not metadata: + sql_session.add( + StorageMetadata( + key=_schema_check.SCHEMA_VERSION_KEY, + value=_schema_check.CURRENT_SCHEMA_VERSION, + ) + ) + await sql_session.commit() self._tables_created = True @override @@ -723,7 +623,9 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_session.state = storage_session.state | session_state_delta if storage_session._dialect_name == "sqlite": - update_time = datetime.utcfromtimestamp(event.timestamp) + update_time = datetime.fromtimestamp( + event.timestamp, timezone.utc + ).replace(tzinfo=None) else: update_time = datetime.fromtimestamp(event.timestamp) storage_session.update_time = update_time diff --git a/src/google/adk/sessions/migration/_schema_check.py b/src/google/adk/sessions/migration/_schema_check.py new file mode 100644 index 0000000000..f6fdc59956 --- /dev/null +++ b/src/google/adk/sessions/migration/_schema_check.py @@ -0,0 +1,114 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Database schema version check utility.""" + +from __future__ import annotations + +import logging + +import sqlalchemy +from sqlalchemy import create_engine as create_sync_engine +from sqlalchemy import inspect +from sqlalchemy import text + +logger = logging.getLogger("google_adk." + __name__) + +SCHEMA_VERSION_KEY = "schema_version" +SCHEMA_VERSION_0_1_PICKLE = "0.1" +SCHEMA_VERSION_1_0_JSON = "1.0" +CURRENT_SCHEMA_VERSION = "1.0" + + +def _to_sync_url(db_url: str) -> str: + """Removes +driver from SQLAlchemy URL.""" + if "://" in db_url: + scheme, _, rest = db_url.partition("://") + if "+" in scheme: + dialect = scheme.split("+", 1)[0] + return f"{dialect}://{rest}" + return db_url + + +def get_version_and_v01_status_sync( + sess: sqlalchemy.orm.Session, +) -> tuple[str | None, bool]: + """Returns (version, is_v01) inspecting the database.""" + inspector = sqlalchemy.inspect(sess.get_bind()) + if inspector.has_table("adk_internal_metadata"): + try: + result = sess.execute( + text("SELECT value FROM adk_internal_metadata WHERE key = :key"), + {"key": SCHEMA_VERSION_KEY}, + ).fetchone() + # If table exists, with or without key, it's 1.0 or newer. + return (result[0] if result else SCHEMA_VERSION_1_0_JSON), False + except Exception as e: + logger.warning( + "Could not read from adk_internal_metadata: %s. Assuming v1.0.", + e, + ) + return SCHEMA_VERSION_1_0_JSON, False + + if inspector.has_table("events"): + try: + cols = {c["name"] for c in inspector.get_columns("events")} + if "actions" in cols and "event_data" not in cols: + return None, True # 0.1 schema + except Exception as e: + logger.warning("Could not inspect 'events' table columns: %s", e) + return None, False # New DB + + +def get_db_schema_version(db_url: str) -> str | None: + """Reads schema version from DB. 
+ + Checks metadata table first, falls back to table structure for 0.1 vs 1.0. + """ + engine = None + try: + engine = create_sync_engine(_to_sync_url(db_url)) + inspector = inspect(engine) + + if inspector.has_table("adk_internal_metadata"): + with engine.connect() as connection: + result = connection.execute( + text("SELECT value FROM adk_internal_metadata WHERE key = :key"), + parameters={"key": SCHEMA_VERSION_KEY}, + ).fetchone() + # If table exists, with or without key, it's 1.0 or newer. + return result[0] if result else SCHEMA_VERSION_1_0_JSON + + # Metadata table doesn't exist, check for 0.1 schema. + # 0.1 schema has an 'events' table with an 'actions' column. + if inspector.has_table("events"): + try: + cols = {c["name"] for c in inspector.get_columns("events")} + if "actions" in cols and "event_data" not in cols: + return SCHEMA_VERSION_0_1_PICKLE + except Exception as e: + logger.warning("Could not inspect 'events' table columns: %s", e) + + # If no metadata table and not identifiable as 0.1, + # assume it is a new/empty DB requiring schema 1.0. + return SCHEMA_VERSION_1_0_JSON + except Exception as e: + logger.info( + "Could not determine schema version by inspecting database: %s." + " Assuming v1.0.", + e, + ) + return SCHEMA_VERSION_1_0_JSON + finally: + if engine: + engine.dispose() diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py new file mode 100644 index 0000000000..f33ef3f5cf --- /dev/null +++ b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py @@ -0,0 +1,492 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Migration script from SQLAlchemy DB with Pickle Events to JSON schema.""" + +from __future__ import annotations + +import argparse +from datetime import datetime +from datetime import timezone +import json +import logging +import pickle +import sys +from typing import Any +from typing import Optional + +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions +from google.adk.sessions import _session_util +from google.adk.sessions import database_session_service as dss +from google.adk.sessions.migration import _schema_check +from google.genai import types +import sqlalchemy +from sqlalchemy import Boolean +from sqlalchemy import create_engine +from sqlalchemy import ForeignKeyConstraint +from sqlalchemy import func +from sqlalchemy import text +from sqlalchemy import Text +from sqlalchemy.dialects import mysql +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import PickleType +from sqlalchemy.types import String +from sqlalchemy.types import TypeDecorator + +logger = logging.getLogger("google_adk." 
+ __name__) + + +# --- Old Schema Definitions --- +class DynamicPickleType(TypeDecorator): + """Represents a type that can be pickled.""" + + impl = PickleType + + def load_dialect_impl(self, dialect): + if dialect.name == "mysql": + return dialect.type_descriptor(mysql.LONGBLOB) + if dialect.name == "spanner+spanner": + from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType + + return dialect.type_descriptor(SpannerPickleType) + return self.impl + + def process_bind_param(self, value, dialect): + """Ensures the pickled value is a bytes object before passing it to the database dialect.""" + if value is not None: + if dialect.name in ("spanner+spanner", "mysql"): + return pickle.dumps(value) + return value + + def process_result_value(self, value, dialect): + """Ensures the raw bytes from the database are unpickled back into a Python object.""" + if value is not None: + if dialect.name in ("spanner+spanner", "mysql"): + return pickle.loads(value) + return value + + +class OldBase(DeclarativeBase): + pass + + +class OldStorageSession(OldBase): + __tablename__ = "sessions" + app_name: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(dss.DynamicJSON), default={} + ) + create_time: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now() + ) + update_time: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + +class OldStorageEvent(OldBase): + """Old storage event with pickle.""" + + __tablename__ = "events" + id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + app_name: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + session_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + invocation_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_VARCHAR_LENGTH) + ) + author: Mapped[str] = mapped_column(String(dss.DEFAULT_MAX_VARCHAR_LENGTH)) + actions: Mapped[Any] = mapped_column(DynamicPickleType) + long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( + Text, nullable=True + ) + branch: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + timestamp: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now() + ) + content: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + grounding_metadata: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + custom_metadata: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + usage_metadata: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + citation_metadata: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + partial: Mapped[bool] = mapped_column(Boolean, nullable=True) + turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) + error_code: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + error_message: Mapped[str] = mapped_column(String(1024), nullable=True) + interrupted: 
Mapped[bool] = mapped_column(Boolean, nullable=True) + input_transcription: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + output_transcription: Mapped[dict[str, Any]] = mapped_column( + dss.DynamicJSON, nullable=True + ) + __table_args__ = ( + ForeignKeyConstraint( + ["app_name", "user_id", "session_id"], + ["sessions.app_name", "sessions.user_id", "sessions.id"], + ondelete="CASCADE", + ), + ) + + @property + def long_running_tool_ids(self) -> set[str]: + return ( + set(json.loads(self.long_running_tool_ids_json)) + if self.long_running_tool_ids_json + else set() + ) + + +def _to_datetime_obj(val: Any) -> datetime | Any: + """Converts string to datetime if needed.""" + if isinstance(val, str): + try: + return datetime.strptime(val, "%Y-%m-%d %H:%M:%S.%f") + except ValueError: + try: + return datetime.strptime(val, "%Y-%m-%d %H:%M:%S") + except ValueError: + pass # return as is if not matching format + return val + + +def _row_to_event(row: dict) -> Event: + """Converts event row (dict) to event object, handling missing columns and deserializing.""" + + actions_val = row.get("actions") + actions = None + if actions_val is not None: + try: + if isinstance(actions_val, bytes): + actions = pickle.loads(actions_val) + else: # for spanner - it might return object directly + actions = actions_val + except Exception as e: + logger.warning( + f"Failed to unpickle actions for event {row.get('id')}: {e}" + ) + actions = None + + if actions and hasattr(actions, "model_dump"): + actions = EventActions().model_copy(update=actions.model_dump()) + elif isinstance(actions, dict): + actions = EventActions(**actions) + else: + actions = EventActions() + + def _safe_json_load(val): + data = None + if isinstance(val, str): + try: + data = json.loads(val) + except json.JSONDecodeError: + logger.warning(f"Failed to decode JSON for event {row.get('id')}") + return None + elif isinstance(val, dict): + data = val # for postgres JSONB + return data + + content_dict = _safe_json_load(row.get("content")) + grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata")) + custom_metadata_dict = _safe_json_load(row.get("custom_metadata")) + usage_metadata_dict = _safe_json_load(row.get("usage_metadata")) + citation_metadata_dict = _safe_json_load(row.get("citation_metadata")) + input_transcription_dict = _safe_json_load(row.get("input_transcription")) + output_transcription_dict = _safe_json_load(row.get("output_transcription")) + + long_running_tool_ids_json = row.get("long_running_tool_ids_json") + long_running_tool_ids = set() + if long_running_tool_ids_json: + try: + long_running_tool_ids = set(json.loads(long_running_tool_ids_json)) + except json.JSONDecodeError: + logger.warning( + "Failed to decode long_running_tool_ids_json for event" + f" {row.get('id')}" + ) + long_running_tool_ids = set() + + event_id = row.get("id") + if not event_id: + raise ValueError("Event must have an id.") + timestamp = _to_datetime_obj(row.get("timestamp")) + if not timestamp: + raise ValueError(f"Event {event_id} must have a timestamp.") + + return Event( + id=event_id, + invocation_id=row.get("invocation_id", ""), + author=row.get("author", "agent"), + branch=row.get("branch"), + actions=actions, + timestamp=timestamp.replace(tzinfo=timezone.utc).timestamp(), + long_running_tool_ids=long_running_tool_ids, + partial=row.get("partial"), + turn_complete=row.get("turn_complete"), + error_code=row.get("error_code"), + error_message=row.get("error_message"), + interrupted=row.get("interrupted"), 
+ custom_metadata=custom_metadata_dict, + content=_session_util.decode_model(content_dict, types.Content), + grounding_metadata=_session_util.decode_model( + grounding_metadata_dict, types.GroundingMetadata + ), + usage_metadata=_session_util.decode_model( + usage_metadata_dict, types.GenerateContentResponseUsageMetadata + ), + citation_metadata=_session_util.decode_model( + citation_metadata_dict, types.CitationMetadata + ), + input_transcription=_session_util.decode_model( + input_transcription_dict, types.Transcription + ), + output_transcription=_session_util.decode_model( + output_transcription_dict, types.Transcription + ), + ) + + +def _get_state_dict(state_val: Any) -> dict: + """Safely load dict from JSON string or return dict if already dict.""" + if isinstance(state_val, dict): + return state_val + if isinstance(state_val, str): + try: + return json.loads(state_val) + except json.JSONDecodeError: + logger.warning( + "Failed to parse state JSON string, defaulting to empty dict." + ) + return {} + return {} + + +class OldStorageAppState(OldBase): + __tablename__ = "app_states" + app_name: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(dss.DynamicJSON), default={} + ) + update_time: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + +class OldStorageUserState(OldBase): + __tablename__ = "user_states" + app_name: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + user_id: Mapped[str] = mapped_column( + String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True + ) + state: Mapped[MutableDict[str, Any]] = mapped_column( + MutableDict.as_mutable(dss.DynamicJSON), default={} + ) + update_time: Mapped[datetime] = mapped_column( + dss.PreciseTimestamp, default=func.now(), onupdate=func.now() + ) + + +# --- Migration Logic --- +def migrate(source_db_url: str, dest_db_url: str): + """Migrates data from old pickle schema to new JSON schema.""" + logger.info(f"Connecting to source database: {source_db_url}") + try: + source_engine = create_engine(source_db_url) + SourceSession = sessionmaker(bind=source_engine) + except Exception as e: + logger.error(f"Failed to connect to source database: {e}") + raise RuntimeError(f"Failed to connect to source database: {e}") from e + + logger.info(f"Connecting to destination database: {dest_db_url}") + try: + dest_engine = create_engine(dest_db_url) + dss.Base.metadata.create_all(dest_engine) + DestSession = sessionmaker(bind=dest_engine) + except Exception as e: + logger.error(f"Failed to connect to destination database: {e}") + raise RuntimeError(f"Failed to connect to destination database: {e}") from e + + with SourceSession() as source_session, DestSession() as dest_session: + dest_session.merge( + dss.StorageMetadata( + key=_schema_check.SCHEMA_VERSION_KEY, + value=_schema_check.SCHEMA_VERSION_1_0_JSON, + ) + ) + dest_session.commit() + try: + inspector = sqlalchemy.inspect(source_engine) + + logger.info("Migrating app_states...") + if inspector.has_table("app_states"): + rows = ( + source_session.execute(text("SELECT * FROM app_states")) + .mappings() + .all() + ) + for row in rows: + dest_session.merge( + dss.StorageAppState( + app_name=row["app_name"], + state=_get_state_dict(row.get("state")), + update_time=_to_datetime_obj(row["update_time"]), + ) + ) + dest_session.commit() + logger.info(f"Migrated {len(rows)} app_states.") + else: + 
logger.info("No 'app_states' table found in source db.") + + logger.info("Migrating user_states...") + if inspector.has_table("user_states"): + rows = ( + source_session.execute(text("SELECT * FROM user_states")) + .mappings() + .all() + ) + for row in rows: + dest_session.merge( + dss.StorageUserState( + app_name=row["app_name"], + user_id=row["user_id"], + state=_get_state_dict(row.get("state")), + update_time=_to_datetime_obj(row["update_time"]), + ) + ) + dest_session.commit() + logger.info(f"Migrated {len(rows)} user_states.") + else: + logger.info("No 'user_states' table found in source db.") + + logger.info("Migrating sessions...") + if inspector.has_table("sessions"): + rows = ( + source_session.execute(text("SELECT * FROM sessions")) + .mappings() + .all() + ) + for row in rows: + dest_session.merge( + dss.StorageSession( + app_name=row["app_name"], + user_id=row["user_id"], + id=row["id"], + state=_get_state_dict(row.get("state")), + create_time=_to_datetime_obj(row["create_time"]), + update_time=_to_datetime_obj(row["update_time"]), + ) + ) + dest_session.commit() + logger.info(f"Migrated {len(rows)} sessions.") + else: + logger.info("No 'sessions' table found in source db.") + + logger.info("Migrating events...") + events = [] + if inspector.has_table("events"): + rows = ( + source_session.execute(text("SELECT * FROM events")) + .mappings() + .all() + ) + for row in rows: + try: + event_obj = _row_to_event(dict(row)) + new_event = dss.StorageEvent( + id=event_obj.id, + app_name=row["app_name"], + user_id=row["user_id"], + session_id=row["session_id"], + invocation_id=event_obj.invocation_id, + timestamp=datetime.fromtimestamp( + event_obj.timestamp, timezone.utc + ).replace(tzinfo=None), + event_data=event_obj.model_dump(mode="json", exclude_none=True), + ) + dest_session.merge(new_event) + events.append(new_event) + except Exception as e: + logger.warning( + f"Failed to migrate event row {row.get('id', 'N/A')}: {e}" + ) + dest_session.commit() + logger.info(f"Migrated {len(events)} events.") + else: + logger.info("No 'events' table found in source database.") + + logger.info("Migration completed successfully.") + except Exception as e: + logger.error(f"An error occurred during migration: {e}", exc_info=True) + dest_session.rollback() + raise RuntimeError(f"An error occurred during migration: {e}") from e + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=( + "Migrate ADK sessions from SQLAlchemy Pickle format to JSON format." 
+ ) + ) + parser.add_argument( + "--source_db_url", required=True, help="SQLAlchemy URL of source database" + ) + parser.add_argument( + "--dest_db_url", + required=True, + help="SQLAlchemy URL of destination database", + ) + args = parser.parse_args() + try: + migrate(args.source_db_url, args.dest_db_url) + except Exception as e: + logger.error(f"Migration failed: {e}") + sys.exit(1) diff --git a/src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py similarity index 100% rename from src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py rename to src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py diff --git a/src/google/adk/sessions/migration/migration_runner.py b/src/google/adk/sessions/migration/migration_runner.py new file mode 100644 index 0000000000..d7abbe41f9 --- /dev/null +++ b/src/google/adk/sessions/migration/migration_runner.py @@ -0,0 +1,128 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Migration runner to upgrade schemas to the latest version.""" + +from __future__ import annotations + +import logging +import os +import tempfile + +from google.adk.sessions.migration import _schema_check +from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle + +logger = logging.getLogger("google_adk." + __name__) + +# Migration map where key is start_version and value is +# (end_version, migration_function). +# Each key is a schema version, and its value is a tuple containing: +# (the schema version AFTER this migration step, the migration function to run). +# The migration function should accept (source_db_url, dest_db_url) as +# arguments. +MIGRATIONS = { + _schema_check.SCHEMA_VERSION_0_1_PICKLE: ( + _schema_check.SCHEMA_VERSION_1_0_JSON, + migrate_from_sqlalchemy_pickle.migrate, + ), +} +# The most recent schema version. The migration process stops once this version +# is reached. +LATEST_VERSION = _schema_check.CURRENT_SCHEMA_VERSION + + +def upgrade(source_db_url: str, dest_db_url: str): + """Migrates a database from its current version to the latest version. + + If the source database schema is older than the latest version, this + function applies migration scripts sequentially until the schema reaches the + LATEST_VERSION. + + If multiple migration steps are required, intermediate results are stored in + temporary SQLite database files. This means a multi-step migration + between other database types (e.g. PostgreSQL to PostgreSQL) will use + SQLite for intermediate steps. + + In-place migration (source_db_url == dest_db_url) is not supported, + as migrations always read from a source and write to a destination. + + Args: + source_db_url: The SQLAlchemy URL of the database to migrate from. + dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be + different from source_db_url. + + Raises: + RuntimeError: If source_db_url and dest_db_url are the same, or if no + migration path is found. 
+ """ + current_version = _schema_check.get_db_schema_version(source_db_url) + + if current_version == LATEST_VERSION: + logger.info( + f"Database {source_db_url} is already at latest version" + f" {LATEST_VERSION}. No migration needed." + ) + return + + if source_db_url == dest_db_url: + raise RuntimeError( + "In-place migration is not supported. " + "Please provide a different file for dest_db_url." + ) + + # Build the list of migration steps required to reach LATEST_VERSION. + migrations_to_run = [] + ver = current_version + while ver in MIGRATIONS and ver != LATEST_VERSION: + migrations_to_run.append(MIGRATIONS[ver]) + ver = MIGRATIONS[ver][0] + + if not migrations_to_run: + raise RuntimeError( + "Could not find migration path for schema version" + f" {current_version} to {LATEST_VERSION}." + ) + + temp_files = [] + in_url = source_db_url + try: + for i, (end_version, migrate_func) in enumerate(migrations_to_run): + is_last_step = i == len(migrations_to_run) - 1 + + if is_last_step: + out_url = dest_db_url + else: + # For intermediate steps, create a temporary SQLite DB to store the + # result. + fd, temp_path = tempfile.mkstemp(suffix=".db") + os.close(fd) + out_url = f"sqlite:///{temp_path}" + temp_files.append(temp_path) + logger.debug(f"Created temp db {out_url} for step {i+1}") + + logger.info( + f"Migrating from {in_url} to {out_url} (schema {end_version})..." + ) + migrate_func(in_url, out_url) + logger.info(f"Finished migration step to schema {end_version}.") + # The output of this step becomes the input for the next step. + in_url = out_url + finally: + # Ensure temporary files are cleaned up even if migration fails. + # Cleanup temp files + for path in temp_files: + try: + os.remove(path) + logger.debug(f"Removed temp db {path}") + except OSError as e: + logger.warning(f"Failed to remove temp db file {path}: {e}") diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index 8ba6531f52..e0d44b3872 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -107,7 +107,7 @@ def __init__(self, db_path: str): f"Database {db_path} seems to use an old schema." " Please run the migration command to" " migrate it to the new schema. Example: `python -m" - " google.adk.sessions.migrate_from_sqlalchemy_sqlite" + " google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite" f" --source_db_path {db_path} --dest_db_path" f" {db_path}.new` then backup {db_path} and rename" f" {db_path}.new to {db_path}." diff --git a/tests/unittests/sessions/migration/test_migrations.py b/tests/unittests/sessions/migration/test_migrations.py new file mode 100644 index 0000000000..938387d29b --- /dev/null +++ b/tests/unittests/sessions/migration/test_migrations.py @@ -0,0 +1,106 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
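
As a usage note for the migration runner introduced above, the following is a hedged sketch of how a caller might invoke it, assuming a source SQLite file that still uses the old pickle-era schema; the file names are examples only and are not part of this change.

```python
# Hypothetical usage of the migration runner added in this patch.
# The database file names below are examples only.
from google.adk.sessions.migration import migration_runner

# Source uses the old pickle-era schema; the destination must be a different
# URL, because upgrade() rejects in-place migration.
source_db_url = "sqlite:///sessions_old.db"
dest_db_url = "sqlite:///sessions_new.db"

# upgrade() detects the source schema version and applies migration steps
# until the latest schema version is reached, using temporary SQLite files
# for any intermediate steps.
migration_runner.upgrade(source_db_url, dest_db_url)

# After verifying the migrated data, back up the old file and point the
# session service at the new database.
```
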
+"""Tests for migration scripts.""" + +from __future__ import annotations + +from datetime import datetime +from datetime import timezone + +from google.adk.events.event_actions import EventActions +from google.adk.sessions import database_session_service as dss +from google.adk.sessions.migration import _schema_check +from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + + +def test_migrate_from_sqlalchemy_pickle(tmp_path): + """Tests for migrate_from_sqlalchemy_pickle.""" + source_db_path = tmp_path / "source_pickle.db" + dest_db_path = tmp_path / "dest_json.db" + source_db_url = f"sqlite:///{source_db_path}" + dest_db_url = f"sqlite:///{dest_db_path}" + + # Setup source DB with old pickle schema + source_engine = create_engine(source_db_url) + mfsp.OldBase.metadata.create_all(source_engine) + SourceSession = sessionmaker(bind=source_engine) + source_session = SourceSession() + + # Populate source data + now = datetime.now(timezone.utc) + app_state = mfsp.OldStorageAppState( + app_name="app1", state={"akey": 1}, update_time=now + ) + user_state = mfsp.OldStorageUserState( + app_name="app1", user_id="user1", state={"ukey": 2}, update_time=now + ) + session = mfsp.OldStorageSession( + app_name="app1", + user_id="user1", + id="session1", + state={"skey": 3}, + create_time=now, + update_time=now, + ) + event = mfsp.OldStorageEvent( + id="event1", + app_name="app1", + user_id="user1", + session_id="session1", + invocation_id="invoke1", + author="user", + actions=EventActions(state_delta={"skey": 4}), + timestamp=now, + ) + source_session.add_all([app_state, user_state, session, event]) + source_session.commit() + source_session.close() + + mfsp.migrate(source_db_url, dest_db_url) + + # Verify destination DB + dest_engine = create_engine(dest_db_url) + DestSession = sessionmaker(bind=dest_engine) + dest_session = DestSession() + + metadata = dest_session.query(dss.StorageMetadata).first() + assert metadata is not None + assert metadata.key == _schema_check.SCHEMA_VERSION_KEY + assert metadata.value == _schema_check.SCHEMA_VERSION_1_0_JSON + + app_state_res = dest_session.query(dss.StorageAppState).first() + assert app_state_res is not None + assert app_state_res.app_name == "app1" + assert app_state_res.state == {"akey": 1} + + user_state_res = dest_session.query(dss.StorageUserState).first() + assert user_state_res is not None + assert user_state_res.user_id == "user1" + assert user_state_res.state == {"ukey": 2} + + session_res = dest_session.query(dss.StorageSession).first() + assert session_res is not None + assert session_res.id == "session1" + assert session_res.state == {"skey": 3} + + event_res = dest_session.query(dss.StorageEvent).first() + assert event_res is not None + assert event_res.id == "event1" + assert "state_delta" in event_res.event_data["actions"] + assert event_res.event_data["actions"]["state_delta"] == {"skey": 4} + + dest_session.close() diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py deleted file mode 100644 index e4eb084f88..0000000000 --- a/tests/unittests/sessions/test_dynamic_pickle_type.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import pickle -from unittest import mock - -from google.adk.sessions.database_session_service import DynamicPickleType -import pytest -from sqlalchemy import create_engine -from sqlalchemy.dialects import mysql - - -@pytest.fixture -def pickle_type(): - """Fixture for DynamicPickleType instance.""" - return DynamicPickleType() - - -def test_load_dialect_impl_mysql(pickle_type): - """Test that MySQL dialect uses LONGBLOB.""" - # Mock the MySQL dialect - mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - # Mock the return value of type_descriptor - mock_longblob_type = mock.Mock() - mock_dialect.type_descriptor.return_value = mock_longblob_type - - impl = pickle_type.load_dialect_impl(mock_dialect) - - # Verify type_descriptor was called once with mysql.LONGBLOB - mock_dialect.type_descriptor.assert_called_once_with(mysql.LONGBLOB) - # Verify the return value is what we expect - assert impl == mock_longblob_type - - -def test_load_dialect_impl_spanner(pickle_type): - """Test that Spanner dialect uses SpannerPickleType.""" - # Mock the spanner dialect - mock_dialect = mock.Mock() - mock_dialect.name = "spanner+spanner" - - with mock.patch( - "google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.SpannerPickleType" - ) as mock_spanner_type: - pickle_type.load_dialect_impl(mock_dialect) - mock_dialect.type_descriptor.assert_called_once_with(mock_spanner_type) - - -def test_load_dialect_impl_default(pickle_type): - """Test that other dialects use default PickleType.""" - engine = create_engine("sqlite:///:memory:") - dialect = engine.dialect - impl = pickle_type.load_dialect_impl(dialect) - # Should return the default impl (PickleType) - assert impl == pickle_type.impl - - -@pytest.mark.parametrize( - "dialect_name", - [ - pytest.param("mysql", id="mysql"), - pytest.param("spanner+spanner", id="spanner"), - ], -) -def test_process_bind_param_pickle_dialects(pickle_type, dialect_name): - """Test that MySQL and Spanner dialects pickle the value.""" - mock_dialect = mock.Mock() - mock_dialect.name = dialect_name - - test_data = {"key": "value", "nested": [1, 2, 3]} - result = pickle_type.process_bind_param(test_data, mock_dialect) - - # Should be pickled bytes - assert isinstance(result, bytes) - # Should be able to unpickle back to original - assert pickle.loads(result) == test_data - - -def test_process_bind_param_default(pickle_type): - """Test that other dialects return value as-is.""" - mock_dialect = mock.Mock() - mock_dialect.name = "sqlite" - - test_data = {"key": "value"} - result = pickle_type.process_bind_param(test_data, mock_dialect) - - # Should return value unchanged (SQLAlchemy's PickleType handles it) - assert result == test_data - - -def test_process_bind_param_none(pickle_type): - """Test that None values are handled correctly.""" - mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - result = pickle_type.process_bind_param(None, mock_dialect) - assert result is None - - -@pytest.mark.parametrize( - "dialect_name", - [ - pytest.param("mysql", id="mysql"), - pytest.param("spanner+spanner", id="spanner"), - ], -) -def 
test_process_result_value_pickle_dialects(pickle_type, dialect_name): - """Test that MySQL and Spanner dialects unpickle the value.""" - mock_dialect = mock.Mock() - mock_dialect.name = dialect_name - - test_data = {"key": "value", "nested": [1, 2, 3]} - pickled_data = pickle.dumps(test_data) - - result = pickle_type.process_result_value(pickled_data, mock_dialect) - - # Should be unpickled back to original - assert result == test_data - - -def test_process_result_value_default(pickle_type): - """Test that other dialects return value as-is.""" - mock_dialect = mock.Mock() - mock_dialect.name = "sqlite" - - test_data = {"key": "value"} - result = pickle_type.process_result_value(test_data, mock_dialect) - - # Should return value unchanged (SQLAlchemy's PickleType handles it) - assert result == test_data - - -def test_process_result_value_none(pickle_type): - """Test that None values are handled correctly.""" - mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - result = pickle_type.process_result_value(None, mock_dialect) - assert result is None - - -@pytest.mark.parametrize( - "dialect_name", - [ - pytest.param("mysql", id="mysql"), - pytest.param("spanner+spanner", id="spanner"), - ], -) -def test_roundtrip_pickle_dialects(pickle_type, dialect_name): - """Test full roundtrip for MySQL and Spanner: bind -> result.""" - mock_dialect = mock.Mock() - mock_dialect.name = dialect_name - - original_data = { - "string": "test", - "number": 42, - "list": [1, 2, 3], - "nested": {"a": 1, "b": 2}, - } - - # Simulate bind (Python -> DB) - bound_value = pickle_type.process_bind_param(original_data, mock_dialect) - assert isinstance(bound_value, bytes) - - # Simulate result (DB -> Python) - result_value = pickle_type.process_result_value(bound_value, mock_dialect) - assert result_value == original_data From 7edd7ea9b77fa19433f1e04a7eefccb47ab08b02 Mon Sep 17 00:00:00 2001 From: Ishan Raj Singh Date: Mon, 1 Dec 2025 10:51:07 -0800 Subject: [PATCH 44/63] chore: Add warning for full resource path in VertexAiMemoryBankService agent_engine_id This change updates the docstring for `agent_engine_id` to clarify that only the resource ID is expected. It also adds a warning log if the provided `agent_engine_id` contains a '/' character, suggesting it might be a full resource path, and provides guidance on how to extract the ID. Unit tests are added to verify the warning behavior. Merge: https://github.com/google/adk-python/pull/2941 Co-authored-by: George Weale PiperOrigin-RevId: 838845022 --- .../memory/vertex_ai_memory_bank_service.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/google/adk/memory/vertex_ai_memory_bank_service.py b/src/google/adk/memory/vertex_ai_memory_bank_service.py index b8f434c563..5df012e027 100644 --- a/src/google/adk/memory/vertex_ai_memory_bank_service.py +++ b/src/google/adk/memory/vertex_ai_memory_bank_service.py @@ -48,14 +48,15 @@ def __init__( Args: project: The project ID of the Memory Bank to use. location: The location of the Memory Bank to use. - agent_engine_id: The ID of the agent engine to use for the Memory Bank. + agent_engine_id: The ID of the agent engine to use for the Memory Bank, e.g. '456' in - 'projects/my-project/locations/us-central1/reasoningEngines/456'. + 'projects/my-project/locations/us-central1/reasoningEngines/456'. To + extract from api_resource.name, use: + ``agent_engine.api_resource.name.split('/')[-1]`` express_mode_api_key: The API key to use for Express Mode. 
If not provided, the API key from the GOOGLE_API_KEY environment variable will - be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. - Do not use Google AI Studio API key for this field. For more details, - visit + be used. It will only be used if GOOGLE_GENAI_USE_VERTEXAI is true. Do + not use Google AI Studio API key for this field. For more details, visit https://cloud.google.com/vertex-ai/generative-ai/docs/start/express-mode/overview """ self._project = project @@ -65,6 +66,14 @@ def __init__( project, location, express_mode_api_key ) + if agent_engine_id and '/' in agent_engine_id: + logger.warning( + "agent_engine_id appears to be a full resource path: '%s'. " + "Expected just the ID (e.g., '456'). " + "Extract the ID using: agent_engine.api_resource.name.split('/')[-1]", + agent_engine_id, + ) + @override async def add_session_to_memory(self, session: Session): if not self._agent_engine_id: From 8e82838f1e59128edb3884315f1d94d6d3602632 Mon Sep 17 00:00:00 2001 From: Eitan Yarmush Date: Mon, 1 Dec 2025 11:04:13 -0800 Subject: [PATCH 45/63] fix: Refactor Anthropic integration to support both direct API and Vertex AI This change introduces an `AnthropicLlm` base class for direct Anthropic API calls using `AsyncAnthropic`. The existing `Claude` class now inherits from `AnthropicLlm` and is specialized to use `AsyncAnthropicVertex` for models hosted on Vertex AI. The `messages.create` call is now properly awaited Merge: https://github.com/google/adk-python/pull/2904 Co-authored-by: George Weale PiperOrigin-RevId: 838851026 --- src/google/adk/models/anthropic_llm.py | 33 +++++++++++++++----- tests/unittests/models/test_anthropic_llm.py | 32 +++++++++++++++++++ 2 files changed, 57 insertions(+), 8 deletions(-) diff --git a/src/google/adk/models/anthropic_llm.py b/src/google/adk/models/anthropic_llm.py index f965a9906d..163fbe4571 100644 --- a/src/google/adk/models/anthropic_llm.py +++ b/src/google/adk/models/anthropic_llm.py @@ -28,7 +28,8 @@ from typing import TYPE_CHECKING from typing import Union -from anthropic import AnthropicVertex +from anthropic import AsyncAnthropic +from anthropic import AsyncAnthropicVertex from anthropic import NOT_GIVEN from anthropic import types as anthropic_types from google.genai import types @@ -41,7 +42,7 @@ if TYPE_CHECKING: from .llm_request import LlmRequest -__all__ = ["Claude"] +__all__ = ["AnthropicLlm", "Claude"] logger = logging.getLogger("google_adk." + __name__) @@ -264,15 +265,15 @@ def function_declaration_to_tool_param( ) -class Claude(BaseLlm): - """Integration with Claude models served from Vertex AI. +class AnthropicLlm(BaseLlm): + """Integration with Claude models via the Anthropic API. Attributes: model: The name of the Claude model. max_tokens: The maximum number of tokens to generate. """ - model: str = "claude-3-5-sonnet-v2@20241022" + model: str = "claude-sonnet-4-20250514" max_tokens: int = 8192 @classmethod @@ -304,7 +305,7 @@ async def generate_content_async( else NOT_GIVEN ) # TODO(b/421255973): Enable streaming for anthropic models. 
- message = self._anthropic_client.messages.create( + message = await self._anthropic_client.messages.create( model=llm_request.model, system=llm_request.config.system_instruction, messages=messages, @@ -315,7 +316,23 @@ async def generate_content_async( yield message_to_generate_content_response(message) @cached_property - def _anthropic_client(self) -> AnthropicVertex: + def _anthropic_client(self) -> AsyncAnthropic: + return AsyncAnthropic() + + +class Claude(AnthropicLlm): + """Integration with Claude models served from Vertex AI. + + Attributes: + model: The name of the Claude model. + max_tokens: The maximum number of tokens to generate. + """ + + model: str = "claude-3-5-sonnet-v2@20241022" + + @cached_property + @override + def _anthropic_client(self) -> AsyncAnthropicVertex: if ( "GOOGLE_CLOUD_PROJECT" not in os.environ or "GOOGLE_CLOUD_LOCATION" not in os.environ @@ -325,7 +342,7 @@ def _anthropic_client(self) -> AnthropicVertex: " Anthropic on Vertex." ) - return AnthropicVertex( + return AsyncAnthropicVertex( project_id=os.environ["GOOGLE_CLOUD_PROJECT"], region=os.environ["GOOGLE_CLOUD_LOCATION"], ) diff --git a/tests/unittests/models/test_anthropic_llm.py b/tests/unittests/models/test_anthropic_llm.py index 13d615bc32..e1880abf0d 100644 --- a/tests/unittests/models/test_anthropic_llm.py +++ b/tests/unittests/models/test_anthropic_llm.py @@ -19,6 +19,7 @@ from anthropic import types as anthropic_types from google.adk import version as adk_version from google.adk.models import anthropic_llm +from google.adk.models.anthropic_llm import AnthropicLlm from google.adk.models.anthropic_llm import Claude from google.adk.models.anthropic_llm import content_to_message_param from google.adk.models.anthropic_llm import function_declaration_to_tool_param @@ -359,6 +360,37 @@ async def mock_coro(): assert responses[0].content.parts[0].text == "Hello, how can I help you?" +@pytest.mark.asyncio +async def test_anthropic_llm_generate_content_async( + llm_request, generate_content_response, generate_llm_response +): + anthropic_llm_instance = AnthropicLlm(model="claude-sonnet-4-20250514") + with mock.patch.object( + anthropic_llm_instance, "_anthropic_client" + ) as mock_client: + with mock.patch.object( + anthropic_llm, + "message_to_generate_content_response", + return_value=generate_llm_response, + ): + # Create a mock coroutine that returns the generate_content_response. + async def mock_coro(): + return generate_content_response + + # Assign the coroutine to the mocked method + mock_client.messages.create.return_value = mock_coro() + + responses = [ + resp + async for resp in anthropic_llm_instance.generate_content_async( + llm_request, stream=False + ) + ] + assert len(responses) == 1 + assert isinstance(responses[0], LlmResponse) + assert responses[0].content.parts[0].text == "Hello, how can I help you?" 
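
As a usage note for the refactor above: callers can now pick between the direct-API and Vertex-hosted integrations explicitly. The sketch below is illustrative only; the model names are simply the class defaults, and the selection logic is an assumption rather than part of this change.

```python
# Illustrative sketch: choosing between the direct Anthropic API integration
# and the Vertex-hosted one introduced above. Model names are the class
# defaults and are shown only as examples.
import os

from google.adk.models.anthropic_llm import AnthropicLlm
from google.adk.models.anthropic_llm import Claude


def make_claude_model(use_vertex: bool):
  if use_vertex:
    # Claude routes calls through AsyncAnthropicVertex and reads
    # GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_LOCATION when its client is built.
    return Claude(model="claude-3-5-sonnet-v2@20241022")
  # AnthropicLlm talks to the Anthropic API directly via AsyncAnthropic,
  # which picks up ANTHROPIC_API_KEY from the environment.
  return AnthropicLlm(model="claude-sonnet-4-20250514")


model = make_claude_model(use_vertex="GOOGLE_CLOUD_PROJECT" in os.environ)
```
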
+ + @pytest.mark.asyncio async def test_generate_content_async_with_max_tokens( llm_request, generate_content_response, generate_llm_response From 2a1a41d3ec60376aba14e5a0aa069e645dc121e1 Mon Sep 17 00:00:00 2001 From: Ankur Sharma Date: Mon, 1 Dec 2025 11:20:08 -0800 Subject: [PATCH 46/63] chore: Adding Eval Client label to model calls made during evals Co-authored-by: Ankur Sharma PiperOrigin-RevId: 838857867 --- .../adk/evaluation/local_eval_service.py | 36 +++++---- src/google/adk/models/google_llm.py | 27 +++---- src/google/adk/utils/_client_labels_utils.py | 78 +++++++++++++++++++ tests/unittests/models/test_google_llm.py | 38 +++++---- .../utils/test_client_labels_utils.py | 68 ++++++++++++++++ 5 files changed, 194 insertions(+), 53 deletions(-) create mode 100644 src/google/adk/utils/_client_labels_utils.py create mode 100644 tests/unittests/utils/test_client_labels_utils.py diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 806a8d690d..30344702d2 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -31,6 +31,8 @@ from ..memory.base_memory_service import BaseMemoryService from ..sessions.base_session_service import BaseSessionService from ..sessions.in_memory_session_service import InMemorySessionService +from ..utils._client_labels_utils import client_label_context +from ..utils._client_labels_utils import EVAL_CLIENT_LABEL from ..utils.feature_decorator import experimental from .base_eval_service import BaseEvalService from .base_eval_service import EvaluateConfig @@ -249,11 +251,12 @@ async def _evaluate_single_inference_result( for eval_metric in evaluate_config.eval_metrics: # Perform evaluation of the metric. try: - evaluation_result = await self._evaluate_metric( - eval_metric=eval_metric, - actual_invocations=inference_result.inferences, - expected_invocations=eval_case.conversation, - ) + with client_label_context(EVAL_CLIENT_LABEL): + evaluation_result = await self._evaluate_metric( + eval_metric=eval_metric, + actual_invocations=inference_result.inferences, + expected_invocations=eval_case.conversation, + ) except Exception as e: # We intentionally catch the Exception as we don't want failures to # affect other metric evaluation. 
@@ -403,17 +406,18 @@ async def _perform_inference_single_eval_item( ) try: - inferences = ( - await EvaluationGenerator._generate_inferences_from_root_agent( - root_agent=root_agent, - user_simulator=self._user_simulator_provider.provide(eval_case), - initial_session=initial_session, - session_id=session_id, - session_service=self._session_service, - artifact_service=self._artifact_service, - memory_service=self._memory_service, - ) - ) + with client_label_context(EVAL_CLIENT_LABEL): + inferences = ( + await EvaluationGenerator._generate_inferences_from_root_agent( + root_agent=root_agent, + user_simulator=self._user_simulator_provider.provide(eval_case), + initial_session=initial_session, + session_id=session_id, + session_service=self._session_service, + artifact_service=self._artifact_service, + memory_service=self._memory_service, + ) + ) inference_result.inferences = inferences inference_result.status = InferenceStatus.SUCCESS diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 90c2fece76..93d802ecdc 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -19,8 +19,6 @@ import copy from functools import cached_property import logging -import os -import sys from typing import AsyncGenerator from typing import cast from typing import Optional @@ -31,7 +29,7 @@ from google.genai.errors import ClientError from typing_extensions import override -from .. import version +from ..utils._client_labels_utils import get_client_labels from ..utils.context_utils import Aclosing from ..utils.streaming_utils import StreamingResponseAggregator from ..utils.variant_utils import GoogleLLMVariant @@ -49,8 +47,7 @@ _NEW_LINE = '\n' _EXCLUDED_PART_FIELD = {'inline_data': {'data'}} -_AGENT_ENGINE_TELEMETRY_TAG = 'remote_reasoning_engine' -_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = 'GOOGLE_CLOUD_AGENT_ENGINE_ID' + _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE = """ On how to mitigate this issue, please refer to: @@ -245,7 +242,7 @@ def api_client(self) -> Client: return Client( http_options=types.HttpOptions( - headers=self._tracking_headers, + headers=self._tracking_headers(), retry_options=self.retry_options, ) ) @@ -258,16 +255,12 @@ def _api_backend(self) -> GoogleLLMVariant: else GoogleLLMVariant.GEMINI_API ) - @cached_property def _tracking_headers(self) -> dict[str, str]: - framework_label = f'google-adk/{version.__version__}' - if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME): - framework_label = f'{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}' - language_label = 'gl-python/' + sys.version.split()[0] - version_header_value = f'{framework_label} {language_label}' + labels = get_client_labels() + header_value = ' '.join(labels) tracking_headers = { - 'x-goog-api-client': version_header_value, - 'user-agent': version_header_value, + 'x-goog-api-client': header_value, + 'user-agent': header_value, } return tracking_headers @@ -286,7 +279,7 @@ def _live_api_client(self) -> Client: return Client( http_options=types.HttpOptions( - headers=self._tracking_headers, api_version=self._live_api_version + headers=self._tracking_headers(), api_version=self._live_api_version ) ) @@ -310,7 +303,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: if not llm_request.live_connect_config.http_options.headers: llm_request.live_connect_config.http_options.headers = {} llm_request.live_connect_config.http_options.headers.update( - self._tracking_headers + self._tracking_headers() ) 
llm_request.live_connect_config.http_options.api_version = ( self._live_api_version @@ -397,7 +390,7 @@ async def _preprocess_request(self, llm_request: LlmRequest) -> None: def _merge_tracking_headers(self, headers: dict[str, str]) -> dict[str, str]: """Merge tracking headers to the given headers.""" headers = headers or {} - for key, tracking_header_value in self._tracking_headers.items(): + for key, tracking_header_value in self._tracking_headers().items(): custom_value = headers.get(key, None) if not custom_value: headers[key] = tracking_header_value diff --git a/src/google/adk/utils/_client_labels_utils.py b/src/google/adk/utils/_client_labels_utils.py new file mode 100644 index 0000000000..72858c3c1d --- /dev/null +++ b/src/google/adk/utils/_client_labels_utils.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager +import contextvars +import os +import sys +from typing import List + +from .. import version + +_ADK_LABEL = "google-adk" +_LANGUAGE_LABEL = "gl-python" +_AGENT_ENGINE_TELEMETRY_TAG = "remote_reasoning_engine" +_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME = "GOOGLE_CLOUD_AGENT_ENGINE_ID" + + +EVAL_CLIENT_LABEL = f"google-adk-eval/{version.__version__}" +"""Label used to denote calls emerging to external system as a part of Evals.""" + +# The ContextVar holds client label collected for the current request. +_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar( + "_LABEL_CONTEXT", default=None +) + + +def _get_default_labels() -> List[str]: + """Returns a list of labels that are always added.""" + framework_label = f"{_ADK_LABEL}/{version.__version__}" + + if os.environ.get(_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME): + framework_label = f"{framework_label}+{_AGENT_ENGINE_TELEMETRY_TAG}" + + language_label = f"{_LANGUAGE_LABEL}/" + sys.version.split()[0] + return [framework_label, language_label] + + +@contextmanager +def client_label_context(client_label: str): + """Runs the operation within the context of the given client label.""" + current_client_label = _LABEL_CONTEXT.get() + + if current_client_label is not None: + raise ValueError( + "Client label already exists. You can only add one client label." 
+ ) + + token = _LABEL_CONTEXT.set(client_label) + + try: + yield + finally: + # Restore the previous state of the context variable + _LABEL_CONTEXT.reset(token) + + +def get_client_labels() -> List[str]: + """Returns the current list of client labels that can be added to HTTP Headers.""" + labels = _get_default_labels() + current_client_label = _LABEL_CONTEXT.get() + + if current_client_label: + labels.append(current_client_label) + + return labels diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index f2419daf3f..ddf1b07667 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -22,8 +22,6 @@ from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.models.cache_metadata import CacheMetadata from google.adk.models.gemini_llm_connection import GeminiLlmConnection -from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME -from google.adk.models.google_llm import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.models.google_llm import _build_function_declaration_log from google.adk.models.google_llm import _build_request_log from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE @@ -31,6 +29,8 @@ from google.adk.models.google_llm import Gemini from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse +from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME +from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai.errors import ClientError @@ -142,13 +142,6 @@ def llm_request_with_computer_use(): ) -@pytest.fixture -def mock_os_environ(): - initial_env = os.environ.copy() - with mock.patch.dict(os.environ, initial_env, clear=False) as m: - yield m - - def test_supported_models(): models = Gemini.supported_models() assert len(models) == 4 @@ -193,12 +186,15 @@ def test_client_version_header(): ) -def test_client_version_header_with_agent_engine(mock_os_environ): - os.environ[_AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME] = "my_test_project" +def test_client_version_header_with_agent_engine(monkeypatch): + monkeypatch.setenv( + _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, "my_test_project" + ) model = Gemini(model="gemini-1.5-flash") client = model.api_client - # Check that ADK version with telemetry tag and Python version are present in headers + # Check that ADK version with telemetry tag and Python version are present in + # headers adk_version_with_telemetry = ( f"google-adk/{adk_version.__version__}+{_AGENT_ENGINE_TELEMETRY_TAG}" ) @@ -473,8 +469,9 @@ async def test_generate_content_async_with_custom_headers( """Test that tracking headers are updated when custom headers are provided.""" # Add custom headers to the request config custom_headers = {"custom-header": "custom-value"} - for key in gemini_llm._tracking_headers: - custom_headers[key] = "custom " + gemini_llm._tracking_headers[key] + tracking_headers = gemini_llm._tracking_headers() + for key in tracking_headers: + custom_headers[key] = "custom " + tracking_headers[key] llm_request.config.http_options = types.HttpOptions(headers=custom_headers) with mock.patch.object(gemini_llm, "api_client") as mock_client: @@ -497,8 +494,9 @@ async def mock_coro(): config_arg = call_args.kwargs["config"] for key, value in config_arg.http_options.headers.items(): - 
if key in gemini_llm._tracking_headers: - assert value == gemini_llm._tracking_headers[key] + " custom" + tracking_headers = gemini_llm._tracking_headers() + if key in tracking_headers: + assert value == tracking_headers[key] + " custom" else: assert value == custom_headers[key] @@ -547,7 +545,7 @@ async def mock_coro(): config_arg = call_args.kwargs["config"] expected_headers = custom_headers.copy() - expected_headers.update(gemini_llm._tracking_headers) + expected_headers.update(gemini_llm._tracking_headers()) assert config_arg.http_options.headers == expected_headers assert len(responses) == 2 @@ -601,7 +599,7 @@ async def mock_coro(): assert final_config.http_options is not None assert ( final_config.http_options.headers["x-goog-api-client"] - == gemini_llm._tracking_headers["x-goog-api-client"] + == gemini_llm._tracking_headers()["x-goog-api-client"] ) assert len(responses) == 2 if stream else 1 @@ -635,7 +633,7 @@ def test_live_api_client_properties(gemini_llm): assert http_options.api_version == "v1beta1" # Check that tracking headers are included - tracking_headers = gemini_llm._tracking_headers + tracking_headers = gemini_llm._tracking_headers() for key, value in tracking_headers.items(): assert key in http_options.headers assert value in http_options.headers[key] @@ -673,7 +671,7 @@ async def __aexit__(self, *args): # Verify that tracking headers were merged with custom headers expected_headers = custom_headers.copy() - expected_headers.update(gemini_llm._tracking_headers) + expected_headers.update(gemini_llm._tracking_headers()) assert config_arg.http_options.headers == expected_headers # Verify that API version was set diff --git a/tests/unittests/utils/test_client_labels_utils.py b/tests/unittests/utils/test_client_labels_utils.py new file mode 100644 index 0000000000..b1d6acb001 --- /dev/null +++ b/tests/unittests/utils/test_client_labels_utils.py @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
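
For orientation, the sketch below shows how the new label helpers are meant to compose into the tracking headers that `Gemini._tracking_headers()` now builds: the default `google-adk/<version>` and `gl-python/<version>` labels are always present, and an extra label such as `EVAL_CLIENT_LABEL` is appended only while the context manager is active. This is an illustrative sketch using the module paths added in this patch, not part of the change itself.

```python
# Illustrative sketch of how client labels flow into the request headers.
from google.adk.utils._client_labels_utils import client_label_context
from google.adk.utils._client_labels_utils import EVAL_CLIENT_LABEL
from google.adk.utils._client_labels_utils import get_client_labels


def build_tracking_headers() -> dict[str, str]:
  # Mirrors Gemini._tracking_headers(): labels are joined with spaces.
  value = " ".join(get_client_labels())
  return {"x-goog-api-client": value, "user-agent": value}


# Outside any context: only the google-adk/<version> and gl-python/<version>
# labels are present.
print(build_tracking_headers()["user-agent"])

# During evals, LocalEvalService wraps model calls in this context, so the
# eval label is appended for the duration of the block.
with client_label_context(EVAL_CLIENT_LABEL):
  print(build_tracking_headers()["user-agent"])
```
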
+ +import sys + +from google.adk import version +from google.adk.utils import _client_labels_utils +import pytest + + +def test_get_client_labels_default(): + """Test get_client_labels returns default labels.""" + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 2 + assert f"google-adk/{version.__version__}" == labels[0] + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + + +def test_get_client_labels_with_agent_engine_id(monkeypatch): + """Test get_client_labels returns agent engine tag when env var is set.""" + monkeypatch.setenv( + _client_labels_utils._AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, + "test-agent-id", + ) + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 2 + assert ( + f"google-adk/{version.__version__}+{_client_labels_utils._AGENT_ENGINE_TELEMETRY_TAG}" + == labels[0] + ) + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + + +def test_get_client_labels_with_context(): + """Test get_client_labels includes label from context.""" + with _client_labels_utils.client_label_context("my-label/1.0"): + labels = _client_labels_utils.get_client_labels() + assert len(labels) == 3 + assert f"google-adk/{version.__version__}" == labels[0] + assert f"gl-python/{sys.version.split()[0]}" == labels[1] + assert "my-label/1.0" == labels[2] + + +def test_client_label_context_nested_error(): + """Test client_label_context raises error when nested.""" + with pytest.raises(ValueError, match="Client label already exists"): + with _client_labels_utils.client_label_context("my-label/1.0"): + with _client_labels_utils.client_label_context("another-label/1.0"): + pass + + +def test_eval_client_label(): + """Test EVAL_CLIENT_LABEL has correct format.""" + assert ( + f"google-adk-eval/{version.__version__}" + == _client_labels_utils.EVAL_CLIENT_LABEL + ) From cb19d0714c90cd578551753680f39d8d6076c79b Mon Sep 17 00:00:00 2001 From: Rohit Yanamadala Date: Mon, 1 Dec 2025 12:25:22 -0800 Subject: [PATCH 47/63] fix: Optimize Stale Agent with GraphQL and Search API to resolve 429 Quota errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Merge https://github.com/google/adk-python/pull/3700 ### Description This PR refactors the `adk_stale_agent` to address `429 RESOURCE_EXHAUSTED` errors encountered during workflow execution. The previous implementation was inefficient in fetching issue history (using pagination over the REST API) and lacked server-side filtering, causing excessive API calls and huge token consumption that breached Gemini API quotas. The new implementation switches to a **GraphQL-first approach**, implements server-side filtering via the Search API, adds robust concurrency controls, and significantly improves code maintainability through modular refactoring. ### Root Cause of Failure The previous workflow failed with the following error due to passing too much context to the LLM and processing too many irrelevant issues: ```text google.genai.errors.ClientError: 429 RESOURCE_EXHAUSTED. Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_paid_tier_input_token_count ``` ### Key Changes #### 1. Optimization: REST → GraphQL (`agent.py`) * **Old:** Fetched issue comments and timeline events using multiple paginated REST API calls (`/timeline`). * **New:** Implemented `get_issue_state` using a single **GraphQL** query. This fetches comments, `userContentEdits`, and specific timeline events (Labels, Renames) in one network request. 
* **Refactoring:** The complex analysis logic has been decomposed into focused helper functions (_fetch_graphql_data, _build_history_timeline, _replay_history_to_find_state) for better readability and testing. * **Configurable:** Added GRAPHQL_COMMENT_LIMIT and GRAPHQL_TIMELINE_LIMIT settings to tune context depth * **Impact:** Drastically reduces the data payload size and eliminates multiple API round-trips, significantly lowering the token count sent to the LLM. #### 2. Optimization: Server-Side Filtering (`utils.py`) * **Old:** Fetched *all* open issues via REST and filtered them in Python memory. * **New:** Uses the GitHub Search API (`get_old_open_issue_numbers`) with `created: COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3700 from ryanaiagent:feat/improve-stale-agent 888064eff125ae74f7c3a9ad6c74f98de80243a2 PiperOrigin-RevId: 838885530 --- .github/workflows/stale-bot.yml | 26 +- .../adk_stale_agent/PROMPT_INSTRUCTION.txt | 100 ++- .../samples/adk_stale_agent/README.md | 86 ++- contributing/samples/adk_stale_agent/agent.py | 657 +++++++++++------- contributing/samples/adk_stale_agent/main.py | 189 ++++- .../samples/adk_stale_agent/settings.py | 22 +- contributing/samples/adk_stale_agent/utils.py | 243 ++++++- 7 files changed, 930 insertions(+), 393 deletions(-) diff --git a/.github/workflows/stale-bot.yml b/.github/workflows/stale-bot.yml index 882cb7b432..6948b56459 100644 --- a/.github/workflows/stale-bot.yml +++ b/.github/workflows/stale-bot.yml @@ -1,57 +1,43 @@ -# .github/workflows/stale-issue-auditor.yml - -# Best Practice: Always have a 'name' field at the top. name: ADK Stale Issue Auditor -# The 'on' block defines the triggers. on: - # The 'workflow_dispatch' trigger allows manual runs. workflow_dispatch: - # The 'schedule' trigger runs the bot on a timer. schedule: - # This runs at 6:00 AM UTC (e.g., 10 PM PST). + # This runs at 6:00 AM UTC (10 PM PST) - cron: '0 6 * * *' -# The 'jobs' block contains the work to be done. jobs: - # A unique ID for the job. audit-stale-issues: - # The runner environment. runs-on: ubuntu-latest + timeout-minutes: 60 - # Permissions for the job's temporary GITHUB_TOKEN. - # These are standard and syntactically correct. permissions: issues: write contents: read - # The sequence of steps for the job. steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.11' - name: Install dependencies - # The '|' character allows for multi-line shell commands. run: | python -m pip install --upgrade pip pip install requests google-adk - name: Run Auditor Agent Script - # The 'env' block for setting environment variables. env: GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }} GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} - OWNER: google + OWNER: ${{ github.repository_owner }} REPO: adk-python - ISSUES_PER_RUN: 100 + CONCURRENCY_LIMIT: 3 LLM_MODEL_NAME: "gemini-2.5-flash" PYTHONPATH: contributing/samples - # The final 'run' command. run: python -m adk_stale_agent.main \ No newline at end of file diff --git a/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt b/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt index bb31889b23..8f5f585ff6 100644 --- a/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt +++ b/contributing/samples/adk_stale_agent/PROMPT_INSTRUCTION.txt @@ -1,40 +1,68 @@ -You are a highly intelligent and transparent repository auditor for '{OWNER}/{REPO}'. 
-Your job is to analyze all open issues and report on your findings before taking any action. +You are a highly intelligent repository auditor for '{OWNER}/{REPO}'. +Your job is to analyze a specific issue and report findings before taking action. **Primary Directive:** Ignore any events from users ending in `[bot]`. -**Reporting Directive:** For EVERY issue you analyze, you MUST output a concise, human-readable summary, starting with "Analysis for Issue #[number]:". +**Reporting Directive:** Output a concise summary starting with "Analysis for Issue #[number]:". + +**THRESHOLDS:** +- Stale Threshold: {stale_threshold_days} days. +- Close Threshold: {close_threshold_days} days. **WORKFLOW:** -1. **Context Gathering**: Call `get_repository_maintainers` and `get_all_open_issues`. -2. **Per-Issue Analysis**: For each issue, call `get_issue_state`, passing in the maintainers list. -3. **Decision & Reporting**: Based on the summary from `get_issue_state`, follow this strict decision tree in order. - ---- **DECISION TREE & REPORTING TEMPLATES** --- - -**STEP 1: CHECK FOR ACTIVITY (IS THE ISSUE ACTIVE?)** -- **Condition**: Was the last human action NOT from a maintainer? (i.e., `last_human_commenter_is_maintainer` is `False`). -- **Action**: The author or a third party has acted. The issue is ACTIVE. - - **Report and Action**: If '{STALE_LABEL_NAME}' is present, report: "Analysis for Issue #[number]: Issue is ACTIVE. The last action was a [action type] by a non-maintainer. To get the [action type], you MUST use the value from the 'last_human_action_type' field in the summary you received from the tool." Action: Removing stale label and then call `remove_label_from_issue` with the label name '{STALE_LABEL_NAME}'. Otherwise, report: "Analysis for Issue #[number]: Issue is ACTIVE. No stale label to remove. Action: None." -- **If this condition is met, stop processing this issue.** - -**STEP 2: IF PENDING, MANAGE THE STALE LIFECYCLE.** -- **Condition**: The last human action WAS from a maintainer (`last_human_commenter_is_maintainer` is `True`). The issue is PENDING. -- **Action**: You must now determine the correct state. - - - **First, check if the issue is already STALE.** - - **Condition**: Is the `'{STALE_LABEL_NAME}'` label present in `current_labels`? - - **Action**: The issue is STALE. Your only job is to check if it should be closed. - - **Get Time Difference**: Call `calculate_time_difference` with the `stale_label_applied_at` timestamp. - - **Decision & Report**: If `hours_passed` > **{CLOSE_HOURS_AFTER_STALE_THRESHOLD}**: Report "Analysis for Issue #[number]: STALE. Close threshold met ({CLOSE_HOURS_AFTER_STALE_THRESHOLD} hours) with no author activity." Action: Closing issue and then call `close_as_stale`. Otherwise, report "Analysis for Issue #[number]: STALE. Close threshold not yet met. Action: None." - - - **ELSE (the issue is PENDING but not yet stale):** - - **Analyze Intent**: Semantically analyze the `last_maintainer_comment_text`. Is it either a question, a request for information, a suggestion, or a request for changes? - - **If YES (it is either a question, a request for information, a suggestion, or a request for changes)**: - - **CRITICAL CHECK**: Now, you must verify the author has not already responded. Compare the `last_author_event_time` and the `last_maintainer_comment_time`. 
- - **IF the author has NOT responded** (i.e., `last_author_event_time` is older than `last_maintainer_comment_time` or is null): - - **Get Time Difference**: Call `calculate_time_difference` with the `last_maintainer_comment_time`. - - **Decision & Report**: If `hours_passed` > **{STALE_HOURS_THRESHOLD}**: Report "Analysis for Issue #[number]: PENDING. Stale threshold met ({STALE_HOURS_THRESHOLD} hours)." Action: Marking as stale and then call `add_stale_label_and_comment` and if label name '{REQUEST_CLARIFICATION_LABEL}' is missing then call `add_label_to_issue` with the label name '{REQUEST_CLARIFICATION_LABEL}'. Otherwise, report: "Analysis for Issue #[number]: PENDING. Stale threshold not met. Action: None." - - **ELSE (the author HAS responded)**: - - **Report**: "Analysis for Issue #[number]: PENDING, but author has already responded to the last maintainer request. Action: None." - - **If NO (it is not a request):** - - **Report**: "Analysis for Issue #[number]: PENDING. Maintainer's last comment was not a request. Action: None." \ No newline at end of file +1. **Context Gathering**: Call `get_issue_state`. +2. **Decision**: Follow this strict decision tree using the data returned by the tool. + +--- **DECISION TREE** --- + +**STEP 1: CHECK IF ALREADY STALE** +- **Condition**: Is `is_stale` (from tool) **True**? +- **Action**: + - **Check Role**: Look at `last_action_role`. + + - **IF 'author' OR 'other_user'**: + - **Context**: The user has responded. The issue is now ACTIVE. + - **Action 1**: Call `remove_label_from_issue` with '{STALE_LABEL_NAME}'. + - **Action 2 (ALERT CHECK)**: Look at `maintainer_alert_needed`. + - **IF True**: User edited description silently. + -> **Action**: Call `alert_maintainer_of_edit`. + - **IF False**: User commented normally. No alert needed. + - **Report**: "Analysis for Issue #[number]: ACTIVE. User activity detected. Removed stale label." + + - **IF 'maintainer'**: + - **Check Time**: Check `days_since_stale_label`. + - **If `days_since_stale_label` > {close_threshold_days}**: + - **Action**: Call `close_as_stale`. + - **Report**: "Analysis for Issue #[number]: STALE. Close threshold met. Closing." + - **Else**: + - **Report**: "Analysis for Issue #[number]: STALE. Waiting for close threshold. No action." + +**STEP 2: CHECK IF ACTIVE (NOT STALE)** +- **Condition**: `is_stale` is **False**. +- **Action**: + - **Check Role**: If `last_action_role` is 'author' or 'other_user': + - **Context**: The issue is Active. + - **Action (ALERT CHECK)**: Look at `maintainer_alert_needed`. + - **IF True**: The user edited the description silently, and we haven't alerted yet. + -> **Action**: Call `alert_maintainer_of_edit`. + -> **Report**: "Analysis for Issue #[number]: ACTIVE. Silent update detected (Description Edit). Alerted maintainer." + - **IF False**: + -> **Report**: "Analysis for Issue #[number]: ACTIVE. Last action was by user. No action." + + - **Check Role**: If `last_action_role` is 'maintainer': + - **Proceed to STEP 3.** + +**STEP 3: ANALYZE MAINTAINER INTENT** +- **Context**: The last person to act was a Maintainer. +- **Action**: Read the text in `last_comment_text`. + - **Question Check**: Does the text ask a question, request clarification, ask for logs, or suggest trying a fix? + - **Time Check**: Is `days_since_activity` > {stale_threshold_days}? + + - **DECISION**: + - **IF (Question == YES) AND (Time == YES)**: + - **Action**: Call `add_stale_label_and_comment`. 
+ - **Check**: If '{REQUEST_CLARIFICATION_LABEL}' is not in `current_labels`, call `add_label_to_issue` for it. + - **Report**: "Analysis for Issue #[number]: STALE. Maintainer asked question [days_since_activity] days ago. Marking stale." + - **IF (Question == YES) BUT (Time == NO)**: + - **Report**: "Analysis for Issue #[number]: PENDING. Maintainer asked question, but threshold not met yet. No action." + - **IF (Question == NO)** (e.g., "I am working on this"): + - **Report**: "Analysis for Issue #[number]: ACTIVE. Maintainer gave status update (not a question). No action." \ No newline at end of file diff --git a/contributing/samples/adk_stale_agent/README.md b/contributing/samples/adk_stale_agent/README.md index 17b427d77c..afc47b11cc 100644 --- a/contributing/samples/adk_stale_agent/README.md +++ b/contributing/samples/adk_stale_agent/README.md @@ -1,65 +1,89 @@ # ADK Stale Issue Auditor Agent -This directory contains an autonomous agent designed to audit a GitHub repository for stale issues, helping to maintain repository hygiene and ensure that all open items are actionable. +This directory contains an autonomous, **GraphQL-powered** agent designed to audit a GitHub repository for stale issues. It maintains repository hygiene by ensuring all open items are actionable and responsive. -The agent operates as a "Repository Auditor," proactively scanning all open issues rather than waiting for a specific trigger. It uses a combination of deterministic Python tools and the semantic understanding of a Large Language Model (LLM) to make intelligent decisions about the state of a conversation. +Unlike traditional "Stale Bots" that only look at timestamps, this agent uses a **Unified History Trace** and an **LLM (Large Language Model)** to understand the *context* of a conversation. It distinguishes between a maintainer asking a question (stale candidate) vs. a maintainer providing a status update (active). --- ## Core Logic & Features -The agent's primary goal is to identify issues where a maintainer has requested information from the author, and to manage the lifecycle of that issue based on the author's response (or lack thereof). +The agent operates as a "Repository Auditor," proactively scanning open issues using a high-efficiency decision tree. -**The agent follows a precise decision tree:** +### 1. Smart State Verification (GraphQL) +Instead of making multiple expensive API calls, the agent uses a single **GraphQL** query per issue to reconstruct the entire history of the conversation. It combines: +* **Comments** +* **Description/Body Edits** ("Ghost Edits") +* **Title Renames** +* **State Changes** (Reopens) -1. **Audits All Open Issues:** On each run, the agent fetches a batch of the oldest open issues in the repository. -2. **Identifies Pending Issues:** It analyzes the full timeline of each issue to see if the last human action was a comment from a repository maintainer. -3. **Semantic Intent Analysis:** If the last comment was from a maintainer, the agent uses the LLM to determine if the comment was a **question or a request for clarification**. -4. **Marks as Stale:** If the maintainer's question has gone unanswered by the author for a configurable period (e.g., 7 days), the agent will: - * Apply a `stale` label to the issue. - * Post a comment notifying the author that the issue is now considered stale and will be closed if no further action is taken. - * Proactively add a `request clarification` label if it's missing, to make the issue's state clear. -5. 
**Handles Activity:** If any non-maintainer (the author or a third party) comments on an issue, the agent will automatically remove the `stale` label, marking the issue as active again. -6. **Closes Stale Issues:** If an issue remains in the `stale` state for another configurable period (e.g., 7 days) with no new activity, the agent will post a final comment and close the issue. +It sorts these events chronologically to determine the **Last Active Actor**. -### Self-Configuration +### 2. The "Last Actor" Rule +The agent follows a precise logic flow based on who acted last: -A key feature of this agent is its ability to self-configure. It does not require a hard-coded list of maintainer usernames. On each run, it uses the GitHub API to dynamically fetch the list of users with write access to the repository, ensuring its logic is always based on the current team. +* **If Author/User acted last:** The issue is **ACTIVE**. + * This includes comments, title changes, and *silent* description edits. + * **Action:** The agent immediately removes the `stale` label. + * **Silent Update Alert:** If the user edited the description but *did not* comment, the agent posts a specific alert: *"Notification: The author has updated the issue description..."* to ensure maintainers are notified (since GitHub does not trigger notifications for body edits). + * **Spam Prevention:** The agent checks if it has already alerted about a specific silent edit to avoid spamming the thread. + +* **If Maintainer acted last:** The issue is **POTENTIALLY STALE**. + * The agent passes the text of the maintainer's last comment to the LLM. + +### 3. Semantic Intent Analysis (LLM) +If the maintainer was the last person to speak, the LLM analyzes the comment text to determine intent: +* **Question/Request:** "Can you provide logs?" / "Please try v2.0." + * **Verdict:** **STALE** (Waiting on Author). + * **Action:** If the time threshold is met, the agent adds the `stale` label. It also checks for the `request clarification` label and adds it if missing. +* **Status Update:** "We are working on a fix." / "Added to backlog." + * **Verdict:** **ACTIVE** (Waiting on Maintainer). + * **Action:** No action taken. The issue remains open without stale labels. + +### 4. Lifecycle Management +* **Marking Stale:** After `STALE_HOURS_THRESHOLD` (default: 7 days) of inactivity following a maintainer's question. +* **Closing:** After `CLOSE_HOURS_AFTER_STALE_THRESHOLD` (default: 7 days) of continued inactivity while marked stale. + +--- + +## Performance & Safety + +* **GraphQL Optimized:** Fetches comments, edits, labels, and timeline events in a single network request to minimize latency and API quota usage. +* **Search API Filtering:** Uses the GitHub Search API to pre-filter issues created recently, ensuring the bot doesn't waste cycles analyzing brand-new issues. +* **Rate Limit Aware:** Includes intelligent sleeping and retry logic (exponential backoff) to handle GitHub API rate limits (HTTP 429) gracefully. +* **Execution Metrics:** Logs the time taken and API calls consumed for every issue processed. --- ## Configuration -The agent is configured entirely via environment variables, which should be set as secrets in the GitHub Actions workflow environment. +The agent is configured via environment variables, typically set as secrets in GitHub Actions. ### Required Secrets | Secret Name | Description | | :--- | :--- | -| `GITHUB_TOKEN` | A GitHub Personal Access Token (PAT) with the required permissions. 
It's recommended to use a PAT from a dedicated "bot" account. -| `GOOGLE_API_KEY` | An API key for the Google AI (Gemini) model used for the agent's reasoning. - -### Required PAT Permissions - -The `GITHUB_TOKEN` requires the following **Repository Permissions**: -* **Issues**: `Read & write` (to read issues, add labels, comment, and close) -* **Administration**: `Read-only` (to read the list of repository collaborators/maintainers) +| `GITHUB_TOKEN` | A GitHub Personal Access Token (PAT) or Service Account Token with `repo` scope. | +| `GOOGLE_API_KEY` | An API key for the Google AI (Gemini) model used for reasoning. | ### Optional Configuration -These environment variables can be set in the workflow file to override the defaults in `settings.py`. +These variables control the timing thresholds and model selection. | Variable Name | Description | Default | | :--- | :--- | :--- | -| `STALE_HOURS_THRESHOLD` | The number of hours of inactivity after a maintainer's question before an issue is marked as `stale`. | `168` (7 days) | -| `CLOSE_HOURS_AFTER_STALE_THRESHOLD` | The number of hours after being marked `stale` before an issue is closed. | `168` (7 days) | -| `ISSUES_PER_RUN` | The maximum number of oldest open issues to process in a single workflow run. | `100` | -| `LLM_MODEL_NAME`| LLM model to use. | `gemini-2.5-flash` | +| `STALE_HOURS_THRESHOLD` | Hours of inactivity after a maintainer's question before marking as `stale`. | `168` (7 days) | +| `CLOSE_HOURS_AFTER_STALE_THRESHOLD` | Hours after being marked `stale` before the issue is closed. | `168` (7 days) | +| `LLM_MODEL_NAME`| The specific Gemini model version to use. | `gemini-2.5-flash` | +| `OWNER` | Repository owner (auto-detected in Actions). | (Environment dependent) | +| `REPO` | Repository name (auto-detected in Actions). | (Environment dependent) | --- ## Deployment -To deploy this agent, a GitHub Actions workflow file (`.github/workflows/stale-bot.yml`) is included. This workflow runs on a daily schedule and executes the agent's main script. +To deploy this agent, a GitHub Actions workflow file (`.github/workflows/stale-bot.yml`) is recommended. + +### Directory Structure Note +Because this agent resides within the `adk-python` package structure, the workflow must ensure the script is executed correctly to handle imports. -Ensure the necessary repository secrets are configured and the `stale` and `request clarification` labels exist in the repository. \ No newline at end of file diff --git a/contributing/samples/adk_stale_agent/agent.py b/contributing/samples/adk_stale_agent/agent.py index abcb128288..5235e0352f 100644 --- a/contributing/samples/adk_stale_agent/agent.py +++ b/contributing/samples/adk_stale_agent/agent.py @@ -17,10 +17,16 @@ import logging import os from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple from adk_stale_agent.settings import CLOSE_HOURS_AFTER_STALE_THRESHOLD from adk_stale_agent.settings import GITHUB_BASE_URL -from adk_stale_agent.settings import ISSUES_PER_RUN +from adk_stale_agent.settings import GRAPHQL_COMMENT_LIMIT +from adk_stale_agent.settings import GRAPHQL_EDIT_LIMIT +from adk_stale_agent.settings import GRAPHQL_TIMELINE_LIMIT from adk_stale_agent.settings import LLM_MODEL_NAME from adk_stale_agent.settings import OWNER from adk_stale_agent.settings import REPO @@ -38,20 +44,75 @@ logger = logging.getLogger("google_adk." 
+ __name__) -# --- Primary Tools for the Agent --- +# --- Constants --- +# Used to detect if the bot has already posted an alert to avoid spamming. +BOT_ALERT_SIGNATURE = ( + "**Notification:** The author has updated the issue description" +) + +# --- Global Cache --- +_MAINTAINERS_CACHE: Optional[List[str]] = None + + +def _get_cached_maintainers() -> List[str]: + """ + Fetches the list of repository maintainers. + + This function relies on `utils.get_request` for network resilience. + `get_request` is configured with an HTTPAdapter that automatically performs + exponential backoff retries (up to 6 times) for 5xx errors and rate limits. + + If the retries are exhausted or the data format is invalid, this function + raises a RuntimeError to prevent the bot from running with incorrect permissions. + + Returns: + List[str]: A list of GitHub usernames with push access. + + Raises: + RuntimeError: If the API fails after all retries or returns invalid data. + """ + global _MAINTAINERS_CACHE + if _MAINTAINERS_CACHE is not None: + return _MAINTAINERS_CACHE + + logger.info("Initializing Maintainers Cache...") + + try: + url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/collaborators" + params = {"permission": "push"} + + data = get_request(url, params) + + if isinstance(data, list): + _MAINTAINERS_CACHE = [u["login"] for u in data if "login" in u] + logger.info(f"Cached {len(_MAINTAINERS_CACHE)} maintainers.") + return _MAINTAINERS_CACHE + else: + logger.error( + f"Invalid API response format: Expected list, got {type(data)}" + ) + raise ValueError(f"GitHub API returned non-list data: {data}") + + except Exception as e: + logger.critical( + f"FATAL: Failed to verify repository maintainers. Error: {e}" + ) + raise RuntimeError( + "Maintainer verification failed. processing aborted." + ) from e def load_prompt_template(filename: str) -> str: - """Loads the prompt text file from the same directory as this script. + """ + Loads the raw text content of a prompt file. Args: - filename: The name of the prompt file to load. + filename (str): The name of the file (e.g., 'PROMPT_INSTRUCTION.txt'). Returns: - The content of the file as a string. + str: The file content. """ file_path = os.path.join(os.path.dirname(__file__), filename) - with open(file_path, "r") as f: return f.read() @@ -59,300 +120,399 @@ def load_prompt_template(filename: str) -> str: PROMPT_TEMPLATE = load_prompt_template("PROMPT_INSTRUCTION.txt") -def get_repository_maintainers() -> dict[str, Any]: +def _fetch_graphql_data(item_number: int) -> Dict[str, Any]: """ - Fetches the list of repository collaborators with 'push' (write) access or higher. - This should only be called once per run. + Executes the GraphQL query to fetch raw issue data, including comments, + edits, and timeline events. + + Args: + item_number (int): The GitHub issue number. Returns: - A dictionary with the status and a list of maintainer usernames, or an - error dictionary. - """ - logger.debug("Fetching repository maintainers with push access...") - try: - url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/collaborators" - params = {"permission": "push"} - collaborators_data = get_request(url, params) + Dict[str, Any]: The raw 'issue' object from the GraphQL response. - maintainers = [user["login"] for user in collaborators_data] - logger.info(f"Found {len(maintainers)} repository maintainers.") - logger.debug(f"Maintainer list: {maintainers}") + Raises: + RequestException: If the GraphQL query returns errors or the issue is not found. 
+ """ + query = """ + query($owner: String!, $name: String!, $number: Int!, $commentLimit: Int!, $timelineLimit: Int!) { + repository(owner: $owner, name: $name) { + issue(number: $number) { + author { login } + createdAt + labels(first: 20) { nodes { name } } + + comments(last: $commentLimit) { + nodes { + author { login } + body + createdAt + lastEditedAt + } + } + + userContentEdits(last: $editLimit) { + nodes { + editor { login } + editedAt + } + } + + timelineItems(itemTypes: [LABELED_EVENT, RENAMED_TITLE_EVENT, REOPENED_EVENT], last: $timelineLimit) { + nodes { + __typename + ... on LabeledEvent { + createdAt + actor { login } + label { name } + } + ... on RenamedTitleEvent { + createdAt + actor { login } + } + ... on ReopenedEvent { + createdAt + actor { login } + } + } + } + } + } + } + """ + + variables = { + "owner": OWNER, + "name": REPO, + "number": item_number, + "commentLimit": GRAPHQL_COMMENT_LIMIT, + "editLimit": GRAPHQL_EDIT_LIMIT, + "timelineLimit": GRAPHQL_TIMELINE_LIMIT, + } - return {"status": "success", "maintainers": maintainers} - except RequestException as e: - logger.error(f"Failed to fetch repository maintainers: {e}", exc_info=True) - return error_response(f"Error fetching repository maintainers: {e}") + response = post_request( + f"{GITHUB_BASE_URL}/graphql", {"query": query, "variables": variables} + ) + if "errors" in response: + raise RequestException(f"GraphQL Error: {response['errors'][0]['message']}") -def get_all_open_issues() -> dict[str, Any]: - """Fetches a batch of the oldest open issues for an audit. + data = response.get("data", {}).get("repository", {}).get("issue", {}) + if not data: + raise RequestException(f"Issue #{item_number} not found.") - Returns: - A dictionary containing the status and a list of open issues, or an error - dictionary. - """ - logger.info( - f"Fetching a batch of {ISSUES_PER_RUN} oldest open issues for audit..." - ) - url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues" - params = { - "state": "open", - "sort": "created", - "direction": "asc", - "per_page": ISSUES_PER_RUN, - } - try: - items = get_request(url, params) - logger.info(f"Found {len(items)} open issues to audit.") - return {"status": "success", "items": items} - except RequestException as e: - logger.error(f"Failed to fetch open issues: {e}", exc_info=True) - return error_response(f"Error fetching all open issues: {e}") + return data -def get_issue_state(item_number: int, maintainers: list[str]) -> dict[str, Any]: - """Analyzes an issue's complete history to create a comprehensive state summary. +def _build_history_timeline( + data: Dict[str, Any], +) -> Tuple[List[Dict[str, Any]], List[datetime], Optional[datetime]]: + """ + Parses raw GraphQL data into a unified, chronologically sorted history list. + Also extracts specific event times needed for logic checks. - This function acts as the primary "detective" for the agent. It performs the - complex, deterministic work of fetching and parsing an issue's full history, - allowing the LLM agent to focus on high-level semantic decision-making. + Args: + data (Dict[str, Any]): The raw issue data from `_fetch_graphql_data`. - It is designed to be highly robust by fetching the complete, multi-page history - from the GitHub `/timeline` API. By handling pagination correctly, it ensures - that even issues with a very long history (more than 100 events) are analyzed - in their entirety, preventing incorrect decisions based on incomplete data. 
+ Returns: + Tuple[List[Dict], List[datetime], Optional[datetime]]: + - history: A list of normalized event dictionaries sorted by time. + - label_events: A list of timestamps when the stale label was applied. + - last_bot_alert_time: Timestamp of the last bot silent-edit alert (if any). + """ + issue_author = data.get("author", {}).get("login") + history = [] + label_events = [] + last_bot_alert_time = None + + # 1. Baseline: Issue Creation + history.append({ + "type": "created", + "actor": issue_author, + "time": dateutil.parser.isoparse(data["createdAt"]), + "data": None, + }) + + # 2. Process Comments + for c in data.get("comments", {}).get("nodes", []): + if not c: + continue + + actor = c.get("author", {}).get("login") + c_body = c.get("body", "") + c_time = dateutil.parser.isoparse(c.get("createdAt")) + + # Track bot alerts for spam prevention + if BOT_ALERT_SIGNATURE in c_body: + if last_bot_alert_time is None or c_time > last_bot_alert_time: + last_bot_alert_time = c_time + + if actor and not actor.endswith("[bot]"): + # Use edit time if available, otherwise creation time + e_time = c.get("lastEditedAt") + actual_time = dateutil.parser.isoparse(e_time) if e_time else c_time + history.append({ + "type": "commented", + "actor": actor, + "time": actual_time, + "data": c_body, + }) + + # 3. Process Body Edits ("Ghost Edits") + for e in data.get("userContentEdits", {}).get("nodes", []): + if not e: + continue + actor = e.get("editor", {}).get("login") + if actor and not actor.endswith("[bot]"): + history.append({ + "type": "edited_description", + "actor": actor, + "time": dateutil.parser.isoparse(e.get("editedAt")), + "data": None, + }) + + # 4. Process Timeline Events + for t in data.get("timelineItems", {}).get("nodes", []): + if not t: + continue + + etype = t.get("__typename") + actor = t.get("actor", {}).get("login") + time_val = dateutil.parser.isoparse(t.get("createdAt")) + + if etype == "LabeledEvent": + if t.get("label", {}).get("name") == STALE_LABEL_NAME: + label_events.append(time_val) + continue + + if actor and not actor.endswith("[bot]"): + pretty_type = ( + "renamed_title" if etype == "RenamedTitleEvent" else "reopened" + ) + history.append({ + "type": pretty_type, + "actor": actor, + "time": time_val, + "data": None, + }) + + # Sort chronologically + history.sort(key=lambda x: x["time"]) + return history, label_events, last_bot_alert_time + + +def _replay_history_to_find_state( + history: List[Dict[str, Any]], maintainers: List[str], issue_author: str +) -> Dict[str, Any]: + """ + Replays the unified event history to determine the absolute last actor and their role. Args: - item_number (int): The number of the GitHub issue or pull request to analyze. - maintainers (list[str]): A dynamically fetched list of GitHub usernames to be - considered maintainers. This is used to categorize actors found in - the issue's history. + history (List[Dict]): Chronologically sorted list of events. + maintainers (List[str]): List of maintainer usernames. + issue_author (str): Username of the issue author. Returns: - A dictionary that serves as a clean, factual report summarizing the - issue's state. On failure, it returns a dictionary with an 'error' status. + Dict[str, Any]: A dictionary containing the last state of the issue: + - last_action_role (str): 'author', 'maintainer', or 'other_user'. + - last_activity_time (datetime): Timestamp of the last human action. + - last_action_type (str): The type of the last action (e.g., 'commented'). 
+ - last_comment_text (Optional[str]): The text of the last comment. """ - try: - # Fetch core issue data and prepare for timeline fetching. - logger.debug(f"Fetching full timeline for issue #{item_number}...") - issue_url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}" - issue_data = get_request(issue_url) - - # Fetch All pages from the timeline API to build a complete history. - timeline_url_base = f"{issue_url}/timeline" - timeline_data = [] - page = 1 - - while True: - paginated_url = f"{timeline_url_base}?per_page=100&page={page}" - logger.debug(f"Fetching timeline page {page} for issue #{item_number}...") - events_page = get_request(paginated_url) - if not events_page: - break - timeline_data.extend(events_page) - if len(events_page) < 100: - break - page += 1 + last_action_role = "author" + last_activity_time = history[0]["time"] + last_action_type = "created" + last_comment_text = None + + for event in history: + actor = event["actor"] + etype = event["type"] + + role = "other_user" + if actor == issue_author: + role = "author" + elif actor in maintainers: + role = "maintainer" + + last_action_role = role + last_activity_time = event["time"] + last_action_type = etype + + # Only store text if it was a comment (resets on other events like labels/edits) + if etype == "commented": + last_comment_text = event["data"] + else: + last_comment_text = None + + return { + "last_action_role": last_action_role, + "last_activity_time": last_activity_time, + "last_action_type": last_action_type, + "last_comment_text": last_comment_text, + } - logger.debug( - f"Fetched a total of {len(timeline_data)} timeline events across" - f" {page-1} page(s) for issue #{item_number}." - ) - # Initialize key variables for the analysis. - issue_author = issue_data.get("user", {}).get("login") - current_labels = [label["name"] for label in issue_data.get("labels", [])] +def get_issue_state(item_number: int) -> Dict[str, Any]: + """ + Retrieves the comprehensive state of a GitHub issue using GraphQL. - # Filter and sort all events into a clean, chronological history of human activity. - human_events = [] - for event in timeline_data: - actor = event.get("actor", {}).get("login") - timestamp_str = event.get("created_at") or event.get("submitted_at") + This function orchestrates the fetching, parsing, and analysis of the issue's + history to determine if it is stale, active, or pending maintainer review. - if not actor or not timestamp_str or actor.endswith("[bot]"): - continue + Args: + item_number (int): The GitHub issue number. - event["parsed_time"] = dateutil.parser.isoparse(timestamp_str) - human_events.append(event) + Returns: + Dict[str, Any]: A comprehensive state dictionary for the LLM agent. + Contains keys such as 'last_action_role', 'is_stale', 'days_since_activity', + and 'maintainer_alert_needed'. + """ + try: + maintainers = _get_cached_maintainers() - human_events.sort(key=lambda e: e["parsed_time"]) + # 1. Fetch + raw_data = _fetch_graphql_data(item_number) - # Find the most recent, relevant events by iterating backwards. - last_maintainer_comment = None - stale_label_event_time = None + issue_author = raw_data.get("author", {}).get("login") + labels_list = [ + l["name"] for l in raw_data.get("labels", {}).get("nodes", []) + ] - for event in reversed(human_events): - if ( - not last_maintainer_comment - and event.get("actor", {}).get("login") in maintainers - and event.get("event") == "commented" - ): - last_maintainer_comment = event + # 2. 
Parse & Sort + history, label_events, last_bot_alert_time = _build_history_timeline( + raw_data + ) + # 3. Analyze (Replay) + state = _replay_history_to_find_state(history, maintainers, issue_author) + + # 4. Final Calculations & Alert Logic + current_time = datetime.now(timezone.utc) + days_since_activity = ( + current_time - state["last_activity_time"] + ).total_seconds() / 86400 + + # Stale Checks + is_stale = STALE_LABEL_NAME in labels_list + days_since_stale_label = 0.0 + if is_stale and label_events: + latest_label_time = max(label_events) + days_since_stale_label = ( + current_time - latest_label_time + ).total_seconds() / 86400 + + # Silent Edit Alert Logic + maintainer_alert_needed = False + if ( + state["last_action_role"] in ["author", "other_user"] + and state["last_action_type"] == "edited_description" + ): if ( - not stale_label_event_time - and event.get("event") == "labeled" - and event.get("label", {}).get("name") == STALE_LABEL_NAME + last_bot_alert_time + and last_bot_alert_time > state["last_activity_time"] ): - stale_label_event_time = event["parsed_time"] - - if last_maintainer_comment and stale_label_event_time: - break - - last_author_action = next( - ( - e - for e in reversed(human_events) - if e.get("actor", {}).get("login") == issue_author - ), - None, - ) + logger.info( + f"#{item_number}: Silent edit detected, but Bot already alerted. No" + " spam." + ) + else: + maintainer_alert_needed = True + logger.info(f"#{item_number}: Silent edit detected. Alert needed.") - # Build and return the final summary report for the LLM agent. - last_human_event = human_events[-1] if human_events else None - last_human_actor = ( - last_human_event.get("actor", {}).get("login") - if last_human_event - else None + logger.debug( + f"#{item_number} VERDICT: Role={state['last_action_role']}, " + f"Idle={days_since_activity:.2f}d" ) return { "status": "success", - "issue_author": issue_author, - "current_labels": current_labels, - "last_maintainer_comment_text": ( - last_maintainer_comment.get("body") - if last_maintainer_comment - else None - ), - "last_maintainer_comment_time": ( - last_maintainer_comment["parsed_time"].isoformat() - if last_maintainer_comment - else None - ), - "last_author_event_time": ( - last_author_action["parsed_time"].isoformat() - if last_author_action - else None - ), - "last_author_action_type": ( - last_author_action.get("event") if last_author_action else "unknown" - ), - "last_human_action_type": ( - last_human_event.get("event") if last_human_event else "unknown" - ), - "last_human_commenter_is_maintainer": ( - last_human_actor in maintainers if last_human_actor else False - ), - "stale_label_applied_at": ( - stale_label_event_time.isoformat() - if stale_label_event_time - else None - ), + "last_action_role": state["last_action_role"], + "last_action_type": state["last_action_type"], + "maintainer_alert_needed": maintainer_alert_needed, + "is_stale": is_stale, + "days_since_activity": days_since_activity, + "days_since_stale_label": days_since_stale_label, + "last_comment_text": state["last_comment_text"], + "current_labels": labels_list, + "stale_threshold_days": STALE_HOURS_THRESHOLD / 24, + "close_threshold_days": CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24, } except RequestException as e: + return error_response(f"Network Error: {e}") + except Exception as e: logger.error( - f"Failed to fetch comprehensive issue state for #{item_number}: {e}", - exc_info=True, - ) - return error_response( - f"Error getting comprehensive issue state for #{item_number}: 
{e}" + f"Unexpected error analyzing #{item_number}: {e}", exc_info=True ) + return error_response(f"Analysis Error: {e}") -def calculate_time_difference(timestamp_str: str) -> dict[str, Any]: - """Calculates the difference in hours between a UTC timestamp string and now. +# --- Tool Definitions --- - Args: - timestamp_str: An ISO 8601 formatted timestamp string. - Returns: - A dictionary with the status and the time difference in hours, or an error - dictionary. +def _format_days(hours: float) -> str: """ - try: - if not timestamp_str: - return error_response("Input timestamp is empty.") - event_time = dateutil.parser.isoparse(timestamp_str) - current_time_utc = datetime.now(timezone.utc) - time_difference = current_time_utc - event_time - hours_passed = time_difference.total_seconds() / 3600 - return {"status": "success", "hours_passed": hours_passed} - except (dateutil.parser.ParserError, TypeError) as e: - logger.error( - "Error calculating time difference for timestamp" - f" '{timestamp_str}': {e}", - exc_info=True, - ) - return error_response(f"Error calculating time difference: {e}") + Formats a duration in hours into a clean day string. + + Example: + 168.0 -> "7" + 12.0 -> "0.5" + """ + days = hours / 24 + return f"{days:.1f}" if days % 1 != 0 else f"{int(days)}" def add_label_to_issue(item_number: int, label_name: str) -> dict[str, Any]: - """Adds a specific label to an issue. + """ + Adds a label to the issue. Args: - item_number: The issue number. - label_name: The name of the label to add. - - Returns: - A dictionary indicating the status of the operation. + item_number (int): The GitHub issue number. + label_name (str): The name of the label to add. """ logger.debug(f"Adding label '{label_name}' to issue #{item_number}.") url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels" try: post_request(url, [label_name]) - logger.info( - f"Successfully added label '{label_name}' to issue #{item_number}." - ) return {"status": "success"} except RequestException as e: - logger.error(f"Failed to add '{label_name}' to issue #{item_number}: {e}") return error_response(f"Error adding label: {e}") def remove_label_from_issue( item_number: int, label_name: str ) -> dict[str, Any]: - """Removes a specific label from an issue or PR. + """ + Removes a label from the issue. Args: - item_number: The issue number. - label_name: The name of the label to remove. - - Returns: - A dictionary indicating the status of the operation. + item_number (int): The GitHub issue number. + label_name (str): The name of the label to remove. """ logger.debug(f"Removing label '{label_name}' from issue #{item_number}.") url = f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels/{label_name}" try: delete_request(url) - logger.info( - f"Successfully removed label '{label_name}' from issue #{item_number}." - ) return {"status": "success"} except RequestException as e: - logger.error( - f"Failed to remove '{label_name}' from issue #{item_number}: {e}" - ) return error_response(f"Error removing label: {e}") def add_stale_label_and_comment(item_number: int) -> dict[str, Any]: - """Adds the 'stale' label to an issue and posts a comment explaining why. + """ + Marks the issue as stale with a comment and label. Args: - item_number: The issue number. - - Returns: - A dictionary indicating the status of the operation. + item_number (int): The GitHub issue number. 
""" - logger.debug(f"Adding stale label and comment to issue #{item_number}.") + stale_days_str = _format_days(STALE_HOURS_THRESHOLD) + close_days_str = _format_days(CLOSE_HOURS_AFTER_STALE_THRESHOLD) + comment = ( "This issue has been automatically marked as stale because it has not" - " had recent activity after a maintainer requested clarification. It" - " will be closed if no further activity occurs within" - f" {CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24:.0f} days." + f" had recent activity for {stale_days_str} days after a maintainer" + " requested clarification. It will be closed if no further activity" + f" occurs within {close_days_str} days." ) try: post_request( @@ -363,28 +523,42 @@ def add_stale_label_and_comment(item_number: int) -> dict[str, Any]: f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/labels", [STALE_LABEL_NAME], ) - logger.info(f"Successfully marked issue #{item_number} as stale.") return {"status": "success"} except RequestException as e: - logger.error( - f"Failed to mark issue #{item_number} as stale: {e}", exc_info=True - ) return error_response(f"Error marking issue as stale: {e}") -def close_as_stale(item_number: int) -> dict[str, Any]: - """Posts a final comment and closes an issue or PR as stale. +def alert_maintainer_of_edit(item_number: int) -> dict[str, Any]: + """ + Posts a comment alerting maintainers of a silent description update. Args: - item_number: The issue number. + item_number (int): The GitHub issue number. + """ + # Uses the constant signature to ensure detection logic in get_issue_state works. + comment = f"{BOT_ALERT_SIGNATURE}. Maintainers, please review." + try: + post_request( + f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}/comments", + {"body": comment}, + ) + return {"status": "success"} + except RequestException as e: + return error_response(f"Error posting alert: {e}") - Returns: - A dictionary indicating the status of the operation. + +def close_as_stale(item_number: int) -> dict[str, Any]: + """ + Closes the issue as not planned/stale. + + Args: + item_number (int): The GitHub issue number. """ - logger.debug(f"Closing issue #{item_number} as stale.") + days_str = _format_days(CLOSE_HOURS_AFTER_STALE_THRESHOLD) + comment = ( "This has been automatically closed because it has been marked as stale" - f" for over {CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24:.0f} days." + f" for over {days_str} days." ) try: post_request( @@ -395,40 +569,29 @@ def close_as_stale(item_number: int) -> dict[str, Any]: f"{GITHUB_BASE_URL}/repos/{OWNER}/{REPO}/issues/{item_number}", {"state": "closed"}, ) - logger.info(f"Successfully closed issue #{item_number} as stale.") return {"status": "success"} except RequestException as e: - logger.error( - f"Failed to close issue #{item_number} as stale: {e}", exc_info=True - ) return error_response(f"Error closing issue: {e}") -# --- Agent Definition --- - root_agent = Agent( model=LLM_MODEL_NAME, name="adk_repository_auditor_agent", - description=( - "Audits open issues to manage their state based on conversation" - " history." 
- ), + description="Audits open issues.", instruction=PROMPT_TEMPLATE.format( OWNER=OWNER, REPO=REPO, STALE_LABEL_NAME=STALE_LABEL_NAME, REQUEST_CLARIFICATION_LABEL=REQUEST_CLARIFICATION_LABEL, - STALE_HOURS_THRESHOLD=STALE_HOURS_THRESHOLD, - CLOSE_HOURS_AFTER_STALE_THRESHOLD=CLOSE_HOURS_AFTER_STALE_THRESHOLD, + stale_threshold_days=STALE_HOURS_THRESHOLD / 24, + close_threshold_days=CLOSE_HOURS_AFTER_STALE_THRESHOLD / 24, ), tools=[ add_label_to_issue, add_stale_label_and_comment, - calculate_time_difference, + alert_maintainer_of_edit, close_as_stale, - get_all_open_issues, get_issue_state, - get_repository_maintainers, remove_label_from_issue, ], ) diff --git a/contributing/samples/adk_stale_agent/main.py b/contributing/samples/adk_stale_agent/main.py index f6fba3fba0..d4fe58dd63 100644 --- a/contributing/samples/adk_stale_agent/main.py +++ b/contributing/samples/adk_stale_agent/main.py @@ -15,10 +15,17 @@ import asyncio import logging import time +from typing import Tuple from adk_stale_agent.agent import root_agent +from adk_stale_agent.settings import CONCURRENCY_LIMIT from adk_stale_agent.settings import OWNER from adk_stale_agent.settings import REPO +from adk_stale_agent.settings import SLEEP_BETWEEN_CHUNKS +from adk_stale_agent.settings import STALE_HOURS_THRESHOLD +from adk_stale_agent.utils import get_api_call_count +from adk_stale_agent.utils import get_old_open_issue_numbers +from adk_stale_agent.utils import reset_api_call_count from google.adk.cli.utils import logs from google.adk.runners import InMemoryRunner from google.genai import types @@ -26,49 +33,163 @@ logs.setup_adk_logger(level=logging.INFO) logger = logging.getLogger("google_adk." + __name__) -APP_NAME = "adk_stale_agent_app" -USER_ID = "adk_stale_agent_user" +APP_NAME = "stale_bot_app" +USER_ID = "stale_bot_user" -async def main(): - """Initializes and runs the stale issue agent.""" - logger.info("--- Starting Stale Agent Run ---") - runner = InMemoryRunner(agent=root_agent, app_name=APP_NAME) - session = await runner.session_service.create_session( - user_id=USER_ID, app_name=APP_NAME - ) +async def process_single_issue(issue_number: int) -> Tuple[float, int]: + """ + Processes a single GitHub issue using the AI agent and logs execution metrics. + + Args: + issue_number (int): The GitHub issue number to audit. + + Returns: + Tuple[float, int]: A tuple containing: + - duration (float): Time taken to process the issue in seconds. + - api_calls (int): The number of API calls made during this specific execution. + + Raises: + Exception: catches generic exceptions to prevent one failure from stopping the batch. + """ + start_time = time.perf_counter() + + start_api_calls = get_api_call_count() + + logger.info(f"Processing Issue #{issue_number}...") + logger.debug(f"#{issue_number}: Initializing runner and session.") + + try: + runner = InMemoryRunner(agent=root_agent, app_name=APP_NAME) + session = await runner.session_service.create_session( + user_id=USER_ID, app_name=APP_NAME + ) + + prompt_text = f"Audit Issue #{issue_number}." + prompt_message = types.Content( + role="user", parts=[types.Part(text=prompt_text)] + ) + + logger.debug(f"#{issue_number}: Sending prompt to agent.") - prompt_text = ( - "Find and process all open issues to manage staleness according to your" - " rules." 
+ async for event in runner.run_async( + user_id=USER_ID, session_id=session.id, new_message=prompt_message + ): + if ( + event.content + and event.content.parts + and hasattr(event.content.parts[0], "text") + ): + text = event.content.parts[0].text + if text: + clean_text = text[:150].replace("\n", " ") + logger.info(f"#{issue_number} Decision: {clean_text}...") + + except Exception as e: + logger.error(f"Error processing issue #{issue_number}: {e}", exc_info=True) + + duration = time.perf_counter() - start_time + + end_api_calls = get_api_call_count() + issue_api_calls = end_api_calls - start_api_calls + + logger.info( + f"Issue #{issue_number} finished in {duration:.2f}s " + f"with ~{issue_api_calls} API calls." ) - logger.info(f"Initial Agent Prompt: {prompt_text}\n") - prompt_message = types.Content( - role="user", parts=[types.Part(text=prompt_text)] + + return duration, issue_api_calls + + +async def main(): + """ + Main entry point to run the stale issue bot concurrently. + + Fetches old issues and processes them in batches to respect API rate limits + and concurrency constraints. + """ + logger.info(f"--- Starting Stale Bot for {OWNER}/{REPO} ---") + logger.info(f"Concurrency level set to {CONCURRENCY_LIMIT}") + + reset_api_call_count() + + filter_days = STALE_HOURS_THRESHOLD / 24 + logger.debug(f"Fetching issues older than {filter_days:.2f} days...") + + try: + all_issues = get_old_open_issue_numbers(OWNER, REPO, days_old=filter_days) + except Exception as e: + logger.critical(f"Failed to fetch issue list: {e}", exc_info=True) + return + + total_count = len(all_issues) + + search_api_calls = get_api_call_count() + + if total_count == 0: + logger.info("No issues matched the criteria. Run finished.") + return + + logger.info( + f"Found {total_count} issues to process. " + f"(Initial search used {search_api_calls} API calls)." ) - async for event in runner.run_async( - user_id=USER_ID, session_id=session.id, new_message=prompt_message - ): - if ( - event.content - and event.content.parts - and hasattr(event.content.parts[0], "text") - ): - # Print the agent's "thoughts" and actions for logging purposes - logger.debug(f"** {event.author} (ADK): {event.content.parts[0].text}") + total_processing_time = 0.0 + total_issue_api_calls = 0 + processed_count = 0 + + # Process the list in chunks of size CONCURRENCY_LIMIT + for i in range(0, total_count, CONCURRENCY_LIMIT): + chunk = all_issues[i : i + CONCURRENCY_LIMIT] + current_chunk_num = i // CONCURRENCY_LIMIT + 1 - logger.info(f"--- Stale Agent Run Finished---") + logger.info( + f"--- Starting chunk {current_chunk_num}: Processing issues {chunk} ---" + ) + + tasks = [process_single_issue(issue_num) for issue_num in chunk] + + results = await asyncio.gather(*tasks) + + for duration, api_calls in results: + total_processing_time += duration + total_issue_api_calls += api_calls + + processed_count += len(chunk) + logger.info( + f"--- Finished chunk {current_chunk_num}. Progress:" + f" {processed_count}/{total_count} ---" + ) + + if (i + CONCURRENCY_LIMIT) < total_count: + logger.debug( + f"Sleeping for {SLEEP_BETWEEN_CHUNKS}s to respect rate limits..." 
+ ) + await asyncio.sleep(SLEEP_BETWEEN_CHUNKS) + + total_api_calls_for_run = search_api_calls + total_issue_api_calls + avg_time_per_issue = ( + total_processing_time / total_count if total_count > 0 else 0 + ) + + logger.info("--- Stale Agent Run Finished ---") + logger.info(f"Successfully processed {processed_count} issues.") + logger.info(f"Total API calls made this run: {total_api_calls_for_run}") + logger.info( + f"Average processing time per issue: {avg_time_per_issue:.2f} seconds." + ) if __name__ == "__main__": - start_time = time.time() - logger.info(f"Initializing stale agent for repository: {OWNER}/{REPO}") - logger.info("-" * 80) + start_time = time.perf_counter() - asyncio.run(main()) + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.warning("Bot execution interrupted manually.") + except Exception as e: + logger.critical(f"Unexpected fatal error: {e}", exc_info=True) - logger.info("-" * 80) - end_time = time.time() - duration = end_time - start_time - logger.info(f"Script finished in {duration:.2f} seconds.") + duration = time.perf_counter() - start_time + logger.info(f"Full audit finished in {duration/60:.2f} minutes.") diff --git a/contributing/samples/adk_stale_agent/settings.py b/contributing/samples/adk_stale_agent/settings.py index 1b71e451f3..599c6ef2ea 100644 --- a/contributing/samples/adk_stale_agent/settings.py +++ b/contributing/samples/adk_stale_agent/settings.py @@ -33,7 +33,6 @@ REQUEST_CLARIFICATION_LABEL = "request clarification" # --- THRESHOLDS IN HOURS --- -# These values can be overridden in a .env file for rapid testing (e.g., STALE_HOURS_THRESHOLD=1) # Default: 168 hours (7 days) # The number of hours of inactivity after a maintainer comment before an issue is marked as stale. STALE_HOURS_THRESHOLD = float(os.getenv("STALE_HOURS_THRESHOLD", 168)) @@ -44,6 +43,21 @@ os.getenv("CLOSE_HOURS_AFTER_STALE_THRESHOLD", 168) ) -# --- BATCH SIZE CONFIGURATION --- -# The maximum number of oldest open issues to process in a single run of the bot. -ISSUES_PER_RUN = int(os.getenv("ISSUES_PER_RUN", 100)) +# --- Performance Configuration --- +# The number of issues to process concurrently. +# Higher values are faster but increase the immediate rate of API calls +CONCURRENCY_LIMIT = int(os.getenv("CONCURRENCY_LIMIT", 3)) + +# --- GraphQL Query Limits --- +# The number of most recent comments to fetch for context analysis. +GRAPHQL_COMMENT_LIMIT = int(os.getenv("GRAPHQL_COMMENT_LIMIT", 30)) + +# The number of most recent description edits to fetch. +GRAPHQL_EDIT_LIMIT = int(os.getenv("GRAPHQL_EDIT_LIMIT", 10)) + +# The number of most recent timeline events (labels, renames, reopens) to fetch. +GRAPHQL_TIMELINE_LIMIT = int(os.getenv("GRAPHQL_TIMELINE_LIMIT", 20)) + +# --- Rate Limiting --- +# Time in seconds to wait between processing chunks. +SLEEP_BETWEEN_CHUNKS = float(os.getenv("SLEEP_BETWEEN_CHUNKS", 1.5)) diff --git a/contributing/samples/adk_stale_agent/utils.py b/contributing/samples/adk_stale_agent/utils.py index 0efb051f72..a396c22ac7 100644 --- a/contributing/samples/adk_stale_agent/utils.py +++ b/contributing/samples/adk_stale_agent/utils.py @@ -12,48 +12,249 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from datetime import datetime +from datetime import timedelta +from datetime import timezone +import logging +import threading from typing import Any +from typing import Dict +from typing import List +from typing import Optional from adk_stale_agent.settings import GITHUB_TOKEN +from adk_stale_agent.settings import STALE_HOURS_THRESHOLD +import dateutil.parser import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry +logger = logging.getLogger("google_adk." + __name__) + +# --- API Call Counter for Monitoring --- +_api_call_count = 0 +_counter_lock = threading.Lock() + + +def get_api_call_count() -> int: + """ + Returns the total number of API calls made since the last reset. + + Returns: + int: The global count of API calls. + """ + with _counter_lock: + return _api_call_count + + +def reset_api_call_count() -> None: + """Resets the global API call counter to zero.""" + global _api_call_count + with _counter_lock: + _api_call_count = 0 + + +def _increment_api_call_count() -> None: + """ + Atomically increments the global API call counter. + Required because the agent may run tools in parallel threads. + """ + global _api_call_count + with _counter_lock: + _api_call_count += 1 + + +# --- Production-Ready HTTP Session with Exponential Backoff --- + +# Configure the retry strategy: +retry_strategy = Retry( + total=6, + backoff_factor=2, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=[ + "HEAD", + "GET", + "POST", + "PUT", + "DELETE", + "OPTIONS", + "TRACE", + "PATCH", + ], +) + +adapter = HTTPAdapter(max_retries=retry_strategy) + +# Create a single, reusable Session object for connection pooling _session = requests.Session() +_session.mount("https://", adapter) +_session.mount("http://", adapter) + _session.headers.update({ "Authorization": f"token {GITHUB_TOKEN}", "Accept": "application/vnd.github.v3+json", }) -def get_request(url: str, params: dict[str, Any] | None = None) -> Any: - """Sends a GET request to the GitHub API.""" - response = _session.get(url, params=params or {}, timeout=60) - response.raise_for_status() - return response.json() +def get_request(url: str, params: Optional[Dict[str, Any]] = None) -> Any: + """ + Sends a GET request to the GitHub API with automatic retries. + + Args: + url (str): The URL endpoint. + params (Optional[Dict[str, Any]]): Query parameters. + + Returns: + Any: The JSON response parsed into a dict or list. + + Raises: + requests.exceptions.RequestException: If retries are exhausted. + """ + _increment_api_call_count() + try: + response = _session.get(url, params=params or {}, timeout=60) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"GET request failed for {url}: {e}") + raise def post_request(url: str, payload: Any) -> Any: - """Sends a POST request to the GitHub API.""" - response = _session.post(url, json=payload, timeout=60) - response.raise_for_status() - return response.json() + """ + Sends a POST request to the GitHub API with automatic retries. + + Args: + url (str): The URL endpoint. + payload (Any): The JSON payload. + + Returns: + Any: The JSON response. 
+ """ + _increment_api_call_count() + try: + response = _session.post(url, json=payload, timeout=60) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"POST request failed for {url}: {e}") + raise def patch_request(url: str, payload: Any) -> Any: - """Sends a PATCH request to the GitHub API.""" - response = _session.patch(url, json=payload, timeout=60) - response.raise_for_status() - return response.json() + """ + Sends a PATCH request to the GitHub API with automatic retries. + + Args: + url (str): The URL endpoint. + payload (Any): The JSON payload. + + Returns: + Any: The JSON response. + """ + _increment_api_call_count() + try: + response = _session.patch(url, json=payload, timeout=60) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"PATCH request failed for {url}: {e}") + raise def delete_request(url: str) -> Any: - """Sends a DELETE request to the GitHub API.""" - response = _session.delete(url, timeout=60) - response.raise_for_status() - if response.status_code == 204: - return {"status": "success"} - return response.json() + """ + Sends a DELETE request to the GitHub API with automatic retries. + Args: + url (str): The URL endpoint. -def error_response(error_message: str) -> dict[str, Any]: - """Creates a standardized error dictionary for the agent.""" + Returns: + Any: A success dict if 204, else the JSON response. + """ + _increment_api_call_count() + try: + response = _session.delete(url, timeout=60) + response.raise_for_status() + if response.status_code == 204: + return {"status": "success", "message": "Deletion successful."} + return response.json() + except requests.exceptions.RequestException as e: + logger.error(f"DELETE request failed for {url}: {e}") + raise + + +def error_response(error_message: str) -> Dict[str, Any]: + """ + Creates a standardized error response dictionary for tool outputs. + + Args: + error_message (str): The error details. + + Returns: + Dict[str, Any]: Standardized error object. + """ return {"status": "error", "message": error_message} + + +def get_old_open_issue_numbers( + owner: str, repo: str, days_old: Optional[float] = None +) -> List[int]: + """ + Finds open issues older than the specified threshold using server-side filtering. 
+ + OPTIMIZATION: + Instead of fetching ALL issues and filtering in Python (which wastes API calls), + this uses the GitHub Search API `created: Date: Mon, 1 Dec 2025 13:15:38 -0800 Subject: [PATCH 48/63] chore: Move simulation related modules to a sub-package in evaluation Co-authored-by: Keyur Joshi PiperOrigin-RevId: 838904075 --- src/google/adk/cli/cli_tools_click.py | 2 +- src/google/adk/evaluation/agent_evaluator.py | 2 +- src/google/adk/evaluation/eval_config.py | 2 +- .../adk/evaluation/evaluation_generator.py | 6 +++--- src/google/adk/evaluation/local_eval_service.py | 2 +- src/google/adk/evaluation/simulation/__init__.py | 13 +++++++++++++ .../llm_backed_user_simulator.py | 16 ++++++++-------- .../{ => simulation}/static_user_simulator.py | 8 ++++---- .../{ => simulation}/user_simulator.py | 8 ++++---- .../{ => simulation}/user_simulator_provider.py | 4 ++-- .../unittests/evaluation/simulation/__init__.py | 13 +++++++++++++ .../test_llm_backed_user_simulator.py | 8 ++++---- .../test_static_user_simulator.py | 4 ++-- .../{ => simulation}/test_user_simulator.py | 4 ++-- .../test_user_simulator_provider.py | 10 +++++----- .../evaluation/test_evaluation_generator.py | 6 +++--- 16 files changed, 67 insertions(+), 41 deletions(-) create mode 100644 src/google/adk/evaluation/simulation/__init__.py rename src/google/adk/evaluation/{ => simulation}/llm_backed_user_simulator.py (96%) rename src/google/adk/evaluation/{ => simulation}/static_user_simulator.py (93%) rename src/google/adk/evaluation/{ => simulation}/user_simulator.py (95%) rename src/google/adk/evaluation/{ => simulation}/user_simulator_provider.py (97%) create mode 100644 tests/unittests/evaluation/simulation/__init__.py rename tests/unittests/evaluation/{ => simulation}/test_llm_backed_user_simulator.py (95%) rename tests/unittests/evaluation/{ => simulation}/test_static_user_simulator.py (93%) rename tests/unittests/evaluation/{ => simulation}/test_user_simulator.py (90%) rename tests/unittests/evaluation/{ => simulation}/test_user_simulator_provider.py (86%) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index e519427259..6c3e7b98a9 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -627,7 +627,7 @@ def cli_eval( from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import load_eval_set_from_file from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager - from ..evaluation.user_simulator_provider import UserSimulatorProvider + from ..evaluation.simulation.user_simulator_provider import UserSimulatorProvider from .cli_eval import _collect_eval_results from .cli_eval import _collect_inferences from .cli_eval import get_root_agent diff --git a/src/google/adk/evaluation/agent_evaluator.py b/src/google/adk/evaluation/agent_evaluator.py index cafa712f56..514681bfa9 100644 --- a/src/google/adk/evaluation/agent_evaluator.py +++ b/src/google/adk/evaluation/agent_evaluator.py @@ -50,7 +50,7 @@ from .evaluator import EvalStatus from .in_memory_eval_sets_manager import InMemoryEvalSetsManager from .local_eval_sets_manager import convert_eval_set_to_pydantic_schema -from .user_simulator_provider import UserSimulatorProvider +from .simulation.user_simulator_provider import UserSimulatorProvider logger = logging.getLogger("google_adk." 
+ __name__) diff --git a/src/google/adk/evaluation/eval_config.py b/src/google/adk/evaluation/eval_config.py index d5b94af5e1..13b2e92274 100644 --- a/src/google/adk/evaluation/eval_config.py +++ b/src/google/adk/evaluation/eval_config.py @@ -27,7 +27,7 @@ from ..evaluation.eval_metrics import EvalMetric from .eval_metrics import BaseCriterion from .eval_metrics import Threshold -from .user_simulator import BaseUserSimulatorConfig +from .simulation.user_simulator import BaseUserSimulatorConfig logger = logging.getLogger("google_adk." + __name__) diff --git a/src/google/adk/evaluation/evaluation_generator.py b/src/google/adk/evaluation/evaluation_generator.py index e9c7dc5436..5d8b48c150 100644 --- a/src/google/adk/evaluation/evaluation_generator.py +++ b/src/google/adk/evaluation/evaluation_generator.py @@ -45,9 +45,9 @@ from .eval_case import SessionInput from .eval_set import EvalSet from .request_intercepter_plugin import _RequestIntercepterPlugin -from .user_simulator import Status as UserSimulatorStatus -from .user_simulator import UserSimulator -from .user_simulator_provider import UserSimulatorProvider +from .simulation.user_simulator import Status as UserSimulatorStatus +from .simulation.user_simulator import UserSimulator +from .simulation.user_simulator_provider import UserSimulatorProvider _USER_AUTHOR = "user" _DEFAULT_AUTHOR = "agent" diff --git a/src/google/adk/evaluation/local_eval_service.py b/src/google/adk/evaluation/local_eval_service.py index 30344702d2..f454266e00 100644 --- a/src/google/adk/evaluation/local_eval_service.py +++ b/src/google/adk/evaluation/local_eval_service.py @@ -55,7 +55,7 @@ from .evaluator import PerInvocationResult from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY from .metric_evaluator_registry import MetricEvaluatorRegistry -from .user_simulator_provider import UserSimulatorProvider +from .simulation.user_simulator_provider import UserSimulatorProvider logger = logging.getLogger('google_adk.' + __name__) diff --git a/src/google/adk/evaluation/simulation/__init__.py b/src/google/adk/evaluation/simulation/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/src/google/adk/evaluation/simulation/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
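For reference, downstream code that previously imported the user-simulator classes from the flat google.adk.evaluation modules now needs the simulation sub-package paths introduced by this patch. A minimal sketch of the updated imports follows (module and class names are taken from the renames and import updates in this patch; assuming calling code only needs to swap paths, with no other changes implied):

    # Old (pre-patch) locations, for comparison:
    #   from google.adk.evaluation.user_simulator import UserSimulator
    #   from google.adk.evaluation.static_user_simulator import StaticUserSimulator
    #
    # New locations after the modules move into the `simulation` sub-package:
    from google.adk.evaluation.simulation.user_simulator import UserSimulator
    from google.adk.evaluation.simulation.user_simulator import NextUserMessage
    from google.adk.evaluation.simulation.user_simulator import Status
    from google.adk.evaluation.simulation.static_user_simulator import StaticUserSimulator
    from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulator
    from google.adk.evaluation.simulation.user_simulator_provider import UserSimulatorProvider
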
diff --git a/src/google/adk/evaluation/llm_backed_user_simulator.py b/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py similarity index 96% rename from src/google/adk/evaluation/llm_backed_user_simulator.py rename to src/google/adk/evaluation/simulation/llm_backed_user_simulator.py index 2fbfcc44d1..4af228772d 100644 --- a/src/google/adk/evaluation/llm_backed_user_simulator.py +++ b/src/google/adk/evaluation/simulation/llm_backed_user_simulator.py @@ -22,14 +22,14 @@ from pydantic import Field from typing_extensions import override -from ..events.event import Event -from ..models.llm_request import LlmRequest -from ..models.registry import LLMRegistry -from ..utils.context_utils import Aclosing -from ..utils.feature_decorator import experimental -from ._retry_options_utils import add_default_retry_options_if_not_present -from .conversation_scenarios import ConversationScenario -from .evaluator import Evaluator +from ...events.event import Event +from ...models.llm_request import LlmRequest +from ...models.registry import LLMRegistry +from ...utils.context_utils import Aclosing +from ...utils.feature_decorator import experimental +from .._retry_options_utils import add_default_retry_options_if_not_present +from ..conversation_scenarios import ConversationScenario +from ..evaluator import Evaluator from .user_simulator import BaseUserSimulatorConfig from .user_simulator import NextUserMessage from .user_simulator import Status diff --git a/src/google/adk/evaluation/static_user_simulator.py b/src/google/adk/evaluation/simulation/static_user_simulator.py similarity index 93% rename from src/google/adk/evaluation/static_user_simulator.py rename to src/google/adk/evaluation/simulation/static_user_simulator.py index 4c5e2cb54d..e1de18a706 100644 --- a/src/google/adk/evaluation/static_user_simulator.py +++ b/src/google/adk/evaluation/simulation/static_user_simulator.py @@ -19,10 +19,10 @@ from typing_extensions import override -from ..events.event import Event -from ..utils.feature_decorator import experimental -from .eval_case import StaticConversation -from .evaluator import Evaluator +from ...events.event import Event +from ...utils.feature_decorator import experimental +from ..eval_case import StaticConversation +from ..evaluator import Evaluator from .user_simulator import BaseUserSimulatorConfig from .user_simulator import NextUserMessage from .user_simulator import Status diff --git a/src/google/adk/evaluation/user_simulator.py b/src/google/adk/evaluation/simulation/user_simulator.py similarity index 95% rename from src/google/adk/evaluation/user_simulator.py rename to src/google/adk/evaluation/simulation/user_simulator.py index c5ab013d7c..57656b76de 100644 --- a/src/google/adk/evaluation/user_simulator.py +++ b/src/google/adk/evaluation/simulation/user_simulator.py @@ -26,10 +26,10 @@ from pydantic import model_validator from pydantic import ValidationError -from ..events.event import Event -from ..utils.feature_decorator import experimental -from .common import EvalBaseModel -from .evaluator import Evaluator +from ...events.event import Event +from ...utils.feature_decorator import experimental +from ..common import EvalBaseModel +from ..evaluator import Evaluator class BaseUserSimulatorConfig(BaseModel): diff --git a/src/google/adk/evaluation/user_simulator_provider.py b/src/google/adk/evaluation/simulation/user_simulator_provider.py similarity index 97% rename from src/google/adk/evaluation/user_simulator_provider.py rename to 
src/google/adk/evaluation/simulation/user_simulator_provider.py index 1aea8c8c92..b1bfd3226c 100644 --- a/src/google/adk/evaluation/user_simulator_provider.py +++ b/src/google/adk/evaluation/simulation/user_simulator_provider.py @@ -16,8 +16,8 @@ from typing import Optional -from ..utils.feature_decorator import experimental -from .eval_case import EvalCase +from ...utils.feature_decorator import experimental +from ..eval_case import EvalCase from .llm_backed_user_simulator import LlmBackedUserSimulator from .static_user_simulator import StaticUserSimulator from .user_simulator import BaseUserSimulatorConfig diff --git a/tests/unittests/evaluation/simulation/__init__.py b/tests/unittests/evaluation/simulation/__init__.py new file mode 100644 index 0000000000..0a2669d7a2 --- /dev/null +++ b/tests/unittests/evaluation/simulation/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unittests/evaluation/test_llm_backed_user_simulator.py b/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py similarity index 95% rename from tests/unittests/evaluation/test_llm_backed_user_simulator.py rename to tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py index 6ef3969e70..75db778bc7 100644 --- a/tests/unittests/evaluation/test_llm_backed_user_simulator.py +++ b/tests/unittests/evaluation/simulation/test_llm_backed_user_simulator.py @@ -15,9 +15,9 @@ from __future__ import annotations from google.adk.evaluation import conversation_scenarios -from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator -from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig -from google.adk.evaluation.user_simulator import Status +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulator +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig +from google.adk.evaluation.simulation.user_simulator import Status from google.adk.events.event import Event from google.genai import types import pytest @@ -112,7 +112,7 @@ async def to_async_iter(items): def mock_llm_agent(mocker): """Provides a mock LLM agent.""" mock_llm_registry_cls = mocker.patch( - "google.adk.evaluation.llm_backed_user_simulator.LLMRegistry" + "google.adk.evaluation.simulation.llm_backed_user_simulator.LLMRegistry" ) mock_llm_registry = mocker.MagicMock() mock_llm_registry_cls.return_value = mock_llm_registry diff --git a/tests/unittests/evaluation/test_static_user_simulator.py b/tests/unittests/evaluation/simulation/test_static_user_simulator.py similarity index 93% rename from tests/unittests/evaluation/test_static_user_simulator.py rename to tests/unittests/evaluation/simulation/test_static_user_simulator.py index 5cc70c80e6..f18c23f5f2 100644 --- a/tests/unittests/evaluation/test_static_user_simulator.py +++ b/tests/unittests/evaluation/simulation/test_static_user_simulator.py @@ -14,9 +14,9 @@ from __future__ 
import annotations -from google.adk.evaluation import static_user_simulator -from google.adk.evaluation import user_simulator from google.adk.evaluation.eval_case import Invocation +from google.adk.evaluation.simulation import static_user_simulator +from google.adk.evaluation.simulation import user_simulator from google.genai import types import pytest diff --git a/tests/unittests/evaluation/test_user_simulator.py b/tests/unittests/evaluation/simulation/test_user_simulator.py similarity index 90% rename from tests/unittests/evaluation/test_user_simulator.py rename to tests/unittests/evaluation/simulation/test_user_simulator.py index c3e1e606ee..dbe7aff1db 100644 --- a/tests/unittests/evaluation/test_user_simulator.py +++ b/tests/unittests/evaluation/simulation/test_user_simulator.py @@ -14,8 +14,8 @@ from __future__ import annotations -from google.adk.evaluation.user_simulator import NextUserMessage -from google.adk.evaluation.user_simulator import Status +from google.adk.evaluation.simulation.user_simulator import NextUserMessage +from google.adk.evaluation.simulation.user_simulator import Status from google.genai.types import Content import pytest diff --git a/tests/unittests/evaluation/test_user_simulator_provider.py b/tests/unittests/evaluation/simulation/test_user_simulator_provider.py similarity index 86% rename from tests/unittests/evaluation/test_user_simulator_provider.py rename to tests/unittests/evaluation/simulation/test_user_simulator_provider.py index 7cff4241b6..c4fb826fb7 100644 --- a/tests/unittests/evaluation/test_user_simulator_provider.py +++ b/tests/unittests/evaluation/simulation/test_user_simulator_provider.py @@ -16,10 +16,10 @@ from google.adk.evaluation import conversation_scenarios from google.adk.evaluation import eval_case -from google.adk.evaluation import user_simulator_provider -from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulator -from google.adk.evaluation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig -from google.adk.evaluation.static_user_simulator import StaticUserSimulator +from google.adk.evaluation.simulation import user_simulator_provider +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulator +from google.adk.evaluation.simulation.llm_backed_user_simulator import LlmBackedUserSimulatorConfig +from google.adk.evaluation.simulation.static_user_simulator import StaticUserSimulator from google.genai import types import pytest @@ -52,7 +52,7 @@ def test_provide_static_user_simulator(self): def test_provide_llm_backed_user_simulator(self, mocker): """Tests the case when a LlmBackedUserSimulator should be provided.""" mock_llm_registry = mocker.patch( - 'google.adk.evaluation.llm_backed_user_simulator.LLMRegistry', + 'google.adk.evaluation.simulation.llm_backed_user_simulator.LLMRegistry', autospec=True, ) mock_llm_registry.return_value.resolve.return_value = mocker.Mock() diff --git a/tests/unittests/evaluation/test_evaluation_generator.py b/tests/unittests/evaluation/test_evaluation_generator.py index 27372f12c2..873239e7f4 100644 --- a/tests/unittests/evaluation/test_evaluation_generator.py +++ b/tests/unittests/evaluation/test_evaluation_generator.py @@ -18,9 +18,9 @@ from google.adk.evaluation.app_details import AppDetails from google.adk.evaluation.evaluation_generator import EvaluationGenerator from google.adk.evaluation.request_intercepter_plugin import _RequestIntercepterPlugin -from google.adk.evaluation.user_simulator import NextUserMessage -from 
google.adk.evaluation.user_simulator import Status as UserSimulatorStatus -from google.adk.evaluation.user_simulator import UserSimulator +from google.adk.evaluation.simulation.user_simulator import NextUserMessage +from google.adk.evaluation.simulation.user_simulator import Status as UserSimulatorStatus +from google.adk.evaluation.simulation.user_simulator import UserSimulator from google.adk.events.event import Event from google.adk.models.llm_request import LlmRequest from google.genai import types From ed9da3fa45cc0d2ee1a88446e8026519eda1134b Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 1 Dec 2025 16:21:27 -0800 Subject: [PATCH 49/63] feat!: Introduction of ADK folder for local session and artifact storage Default CLI session storage to SQLite instead of in-memory Previously, adk run and adk web used in-memory session storage by default, causing sessions to be lost on restart. Now sessions persist to .adk/session.db automatically. To use in-memory storage, pass --session-service-uri memory:// Co-authored-by: George Weale PiperOrigin-RevId: 838975328 --- src/google/adk/cli/cli.py | 5 ++-- src/google/adk/cli/utils/local_storage.py | 26 +++++++++++++++++++ src/google/adk/cli/utils/service_factory.py | 7 +++-- .../cli/utils/test_service_factory.py | 14 +++++++--- 4 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index af57a687fb..a1b63a4c46 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -26,12 +26,10 @@ from ..agents.llm_agent import LlmAgent from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService -from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..runners import Runner from ..sessions.base_session_service import BaseSessionService -from ..sessions.in_memory_session_service import InMemorySessionService from ..sessions.session import Session from ..utils.context_utils import Aclosing from ..utils.env_utils import is_env_enabled @@ -162,8 +160,9 @@ async def run_cli( user_id = 'test_user' # Create session and artifact services using factory functions + # Sessions persist under //.adk/session.db by default. session_service = create_session_service_from_options( - base_dir=agent_root, + base_dir=agent_parent_path, session_service_uri=session_service_uri, ) diff --git a/src/google/adk/cli/utils/local_storage.py b/src/google/adk/cli/utils/local_storage.py index b170d66531..ec7099b8c8 100644 --- a/src/google/adk/cli/utils/local_storage.py +++ b/src/google/adk/cli/utils/local_storage.py @@ -57,6 +57,32 @@ def create_local_database_session_service( return SqliteSessionService(db_path=str(session_db_path)) +def create_local_session_service( + *, + base_dir: Path | str, + per_agent: bool = False, +) -> BaseSessionService: + """Creates a local SQLite-backed session service. + + Args: + base_dir: The base directory for the agent(s). + per_agent: If True, creates a PerAgentDatabaseSessionService that stores + sessions in each agent's .adk folder. If False, creates a single + SqliteSessionService at base_dir/.adk/session.db. + + Returns: + A BaseSessionService instance backed by SQLite. 
+ """ + if per_agent: + logger.info( + "Using per-agent session storage rooted at %s", + base_dir, + ) + return PerAgentDatabaseSessionService(agents_root=base_dir) + + return create_local_database_session_service(base_dir=base_dir) + + def create_local_artifact_service( *, base_dir: Path | str ) -> BaseArtifactService: diff --git a/src/google/adk/cli/utils/service_factory.py b/src/google/adk/cli/utils/service_factory.py index fc2a642c4f..60f4ddd3cf 100644 --- a/src/google/adk/cli/utils/service_factory.py +++ b/src/google/adk/cli/utils/service_factory.py @@ -23,6 +23,7 @@ from ...sessions.base_session_service import BaseSessionService from ..service_registry import get_service_registry from .local_storage import create_local_artifact_service +from .local_storage import create_local_session_service logger = logging.getLogger("google_adk." + __name__) @@ -62,10 +63,8 @@ def create_session_service_from_options( ) return DatabaseSessionService(db_url=session_service_uri, **fallback_kwargs) - logger.info("Using in-memory session service") - from ...sessions.in_memory_session_service import InMemorySessionService - - return InMemorySessionService() + # Default to per-agent local SQLite storage in //.adk/. + return create_local_session_service(base_dir=base_path, per_agent=True) def create_memory_service_from_options( diff --git a/tests/unittests/cli/utils/test_service_factory.py b/tests/unittests/cli/utils/test_service_factory.py index 207c96642a..9d9afdd23b 100644 --- a/tests/unittests/cli/utils/test_service_factory.py +++ b/tests/unittests/cli/utils/test_service_factory.py @@ -19,10 +19,10 @@ from pathlib import Path from unittest.mock import Mock +from google.adk.cli.utils.local_storage import PerAgentDatabaseSessionService import google.adk.cli.utils.service_factory as service_factory from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.sessions.database_session_service import DatabaseSessionService -from google.adk.sessions.in_memory_session_service import InMemorySessionService import pytest @@ -44,12 +44,20 @@ def test_create_session_service_uses_registry(tmp_path: Path, monkeypatch): ) -def test_create_session_service_defaults_to_memory(tmp_path: Path): +@pytest.mark.asyncio +async def test_create_session_service_defaults_to_per_agent_sqlite( + tmp_path: Path, +) -> None: + agent_dir = tmp_path / "agent_a" + agent_dir.mkdir() service = service_factory.create_session_service_from_options( base_dir=tmp_path, ) - assert isinstance(service, InMemorySessionService) + assert isinstance(service, PerAgentDatabaseSessionService) + session = await service.create_session(app_name="agent_a", user_id="user") + assert session.app_name == "agent_a" + assert (agent_dir / ".adk" / "session.db").exists() def test_create_session_service_fallbacks_to_database( From 7e8eeca6aa7cb81f11fd7a0e89fe7dd02b95c88d Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 1 Dec 2025 16:28:45 -0800 Subject: [PATCH 50/63] fix: Add a FastAPI endpoint for saving artifacts This change adds new `POST` endpoint `/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts` to the ADK web server. This endpoint lets clients to save new artifacts associated with a specific session. The endpoint uses `SaveArtifactRequest` and returns `SaveArtifactResponse`, including the version and canonical URI of the saved artifact. 
Close #1975 Co-authored-by: George Weale PiperOrigin-RevId: 838977880 --- .../adk/artifacts/file_artifact_service.py | 13 +- .../adk/artifacts/gcs_artifact_service.py | 7 +- .../artifacts/in_memory_artifact_service.py | 11 +- src/google/adk/cli/adk_web_server.py | 62 +++++ src/google/adk/cli/fast_api.py | 82 +++++-- .../adk/errors/input_validation_error.py | 28 +++ .../artifacts/test_artifact_service.py | 5 +- tests/unittests/cli/test_fast_api.py | 212 ++++++++++++++++-- 8 files changed, 363 insertions(+), 57 deletions(-) create mode 100644 src/google/adk/errors/input_validation_error.py diff --git a/src/google/adk/artifacts/file_artifact_service.py b/src/google/adk/artifacts/file_artifact_service.py index 825e4a7a71..53a830c066 100644 --- a/src/google/adk/artifacts/file_artifact_service.py +++ b/src/google/adk/artifacts/file_artifact_service.py @@ -32,6 +32,7 @@ from pydantic import ValidationError from typing_extensions import override +from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService @@ -100,14 +101,14 @@ def _resolve_scoped_artifact_path( to `scope_root`. Raises: - ValueError: If `filename` resolves outside of `scope_root`. + InputValidationError: If `filename` resolves outside of `scope_root`. """ stripped = _strip_user_namespace(filename).strip() pure_path = _to_posix_path(stripped) scope_root_resolved = scope_root.resolve(strict=False) if pure_path.is_absolute(): - raise ValueError( + raise InputValidationError( f"Absolute artifact filename {filename!r} is not permitted; " "provide a path relative to the storage scope." ) @@ -118,7 +119,7 @@ def _resolve_scoped_artifact_path( try: relative = candidate.relative_to(scope_root_resolved) except ValueError as exc: - raise ValueError( + raise InputValidationError( f"Artifact filename {filename!r} escapes storage directory " f"{scope_root_resolved}" ) from exc @@ -230,7 +231,7 @@ def _scope_root( if _is_user_scoped(session_id, filename): return _user_artifacts_dir(base) if not session_id: - raise ValueError( + raise InputValidationError( "Session ID must be provided for session-scoped artifacts." ) return _session_artifacts_dir(base, session_id) @@ -371,7 +372,9 @@ def _save_artifact_sync( content_path.write_text(artifact.text, encoding="utf-8") mime_type = None else: - raise ValueError("Artifact must have either inline_data or text content.") + raise InputValidationError( + "Artifact must have either inline_data or text content." + ) canonical_uri = self._canonical_uri( user_id=user_id, diff --git a/src/google/adk/artifacts/gcs_artifact_service.py b/src/google/adk/artifacts/gcs_artifact_service.py index fc18dab6fc..2bf713a5e8 100644 --- a/src/google/adk/artifacts/gcs_artifact_service.py +++ b/src/google/adk/artifacts/gcs_artifact_service.py @@ -30,6 +30,7 @@ from google.genai import types from typing_extensions import override +from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService @@ -161,7 +162,7 @@ def _get_blob_prefix( return f"{app_name}/{user_id}/user/{filename}" if session_id is None: - raise ValueError( + raise InputValidationError( "Session ID must be provided for session-scoped artifacts." ) return f"{app_name}/{user_id}/{session_id}/{filename}" @@ -230,7 +231,9 @@ def _save_artifact( " GcsArtifactService." 
) else: - raise ValueError("Artifact must have either inline_data or text.") + raise InputValidationError( + "Artifact must have either inline_data or text." + ) return version diff --git a/src/google/adk/artifacts/in_memory_artifact_service.py b/src/google/adk/artifacts/in_memory_artifact_service.py index 246e8a85fb..2c7dd14127 100644 --- a/src/google/adk/artifacts/in_memory_artifact_service.py +++ b/src/google/adk/artifacts/in_memory_artifact_service.py @@ -18,12 +18,13 @@ from typing import Any from typing import Optional -from google.adk.artifacts import artifact_util from google.genai import types from pydantic import BaseModel from pydantic import Field from typing_extensions import override +from . import artifact_util +from ..errors.input_validation_error import InputValidationError from .base_artifact_service import ArtifactVersion from .base_artifact_service import BaseArtifactService @@ -86,7 +87,7 @@ def _artifact_path( return f"{app_name}/{user_id}/user/{filename}" if session_id is None: - raise ValueError( + raise InputValidationError( "Session ID must be provided for session-scoped artifacts." ) return f"{app_name}/{user_id}/{session_id}/{filename}" @@ -125,7 +126,7 @@ async def save_artifact( elif artifact.file_data is not None: if artifact_util.is_artifact_ref(artifact): if not artifact_util.parse_artifact_uri(artifact.file_data.file_uri): - raise ValueError( + raise InputValidationError( f"Invalid artifact reference URI: {artifact.file_data.file_uri}" ) # If it's a valid artifact URI, we store the artifact part as-is. @@ -133,7 +134,7 @@ async def save_artifact( else: artifact_version.mime_type = artifact.file_data.mime_type else: - raise ValueError("Not supported artifact type.") + raise InputValidationError("Not supported artifact type.") self.artifacts[path].append( _ArtifactEntry(data=artifact, artifact_version=artifact_version) @@ -172,7 +173,7 @@ async def load_artifact( artifact_data.file_data.file_uri ) if not parsed_uri: - raise ValueError( + raise InputValidationError( "Invalid artifact reference URI:" f" {artifact_data.file_data.file_uri}" ) diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 45747a52a1..78fe426628 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -61,9 +61,11 @@ from ..agents.run_config import RunConfig from ..agents.run_config import StreamingMode from ..apps.app import App +from ..artifacts.base_artifact_service import ArtifactVersion from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService from ..errors.already_exists_error import AlreadyExistsError +from ..errors.input_validation_error import InputValidationError from ..errors.not_found_error import NotFoundError from ..evaluation.base_eval_service import InferenceConfig from ..evaluation.base_eval_service import InferenceRequest @@ -194,6 +196,19 @@ class CreateSessionRequest(common.BaseModel): ) +class SaveArtifactRequest(common.BaseModel): + """Request payload for saving a new artifact.""" + + filename: str = Field(description="Artifact filename.") + artifact: types.Part = Field( + description="Artifact payload encoded as google.genai.types.Part." 
+ ) + custom_metadata: Optional[dict[str, Any]] = Field( + default=None, + description="Optional metadata to associate with the artifact version.", + ) + + class AddSessionToEvalSetRequest(common.BaseModel): eval_id: str session_id: str @@ -1316,6 +1331,53 @@ async def load_artifact_version( raise HTTPException(status_code=404, detail="Artifact not found") return artifact + @app.post( + "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", + response_model=ArtifactVersion, + response_model_exclude_none=True, + ) + async def save_artifact( + app_name: str, + user_id: str, + session_id: str, + req: SaveArtifactRequest, + ) -> ArtifactVersion: + try: + version = await self.artifact_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=req.filename, + artifact=req.artifact, + custom_metadata=req.custom_metadata, + ) + except InputValidationError as ive: + raise HTTPException(status_code=400, detail=str(ive)) from ive + except Exception as exc: # pylint: disable=broad-exception-caught + logger.error( + "Internal error while saving artifact %s for app=%s user=%s" + " session=%s: %s", + req.filename, + app_name, + user_id, + session_id, + exc, + exc_info=True, + ) + raise HTTPException(status_code=500, detail=str(exc)) from exc + artifact_version = await self.artifact_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=req.filename, + version=version, + ) + if artifact_version is None: + raise HTTPException( + status_code=500, detail="Artifact metadata unavailable" + ) + return artifact_version + @app.get( "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts", response_model_exclude_none=True, diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index c095b03a30..f9170968fd 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -14,6 +14,7 @@ from __future__ import annotations +import importlib import json import logging import os @@ -34,22 +35,43 @@ from starlette.types import Lifespan from watchdog.observers import Observer +from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager +from ..memory.in_memory_memory_service import InMemoryMemoryService from ..runners import Runner +from ..sessions.in_memory_session_service import InMemorySessionService from .adk_web_server import AdkWebServer +from .service_registry import get_service_registry from .service_registry import load_services_module from .utils import envs from .utils import evals from .utils.agent_change_handler import AgentChangeEventHandler from .utils.agent_loader import AgentLoader -from .utils.service_factory import create_artifact_service_from_options -from .utils.service_factory import create_memory_service_from_options -from .utils.service_factory import create_session_service_from_options logger = logging.getLogger("google_adk." 
+ __name__) +_LAZY_SERVICE_IMPORTS: dict[str, str] = { + "AgentLoader": ".utils.agent_loader", + "InMemoryArtifactService": "..artifacts.in_memory_artifact_service", + "InMemoryMemoryService": "..memory.in_memory_memory_service", + "InMemorySessionService": "..sessions.in_memory_session_service", + "LocalEvalSetResultsManager": "..evaluation.local_eval_set_results_manager", + "LocalEvalSetsManager": "..evaluation.local_eval_sets_manager", +} + + +def __getattr__(name: str): + """Lazily import defaults so patching in tests keeps working.""" + if name not in _LAZY_SERVICE_IMPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + module = importlib.import_module(_LAZY_SERVICE_IMPORTS[name], __package__) + attr = getattr(module, name) + globals()[name] = attr + return attr + def get_fast_api_app( *, @@ -73,8 +95,6 @@ def get_fast_api_app( logo_text: Optional[str] = None, logo_image_url: Optional[str] = None, ) -> FastAPI: - # Convert to absolute path for consistency - agents_dir = str(Path(agents_dir).resolve()) # Set up eval managers. if eval_storage_uri: @@ -92,30 +112,48 @@ def get_fast_api_app( # Load services.py from agents_dir for custom service registration. load_services_module(agents_dir) + service_registry = get_service_registry() + # Build the Memory service - try: - memory_service = create_memory_service_from_options( - base_dir=agents_dir, - memory_service_uri=memory_service_uri, + if memory_service_uri: + memory_service = service_registry.create_memory_service( + memory_service_uri, agents_dir=agents_dir ) - except ValueError as exc: - raise click.ClickException(str(exc)) from exc + if not memory_service: + raise click.ClickException( + "Unsupported memory service URI: %s" % memory_service_uri + ) + else: + memory_service = InMemoryMemoryService() # Build the Session service - session_service = create_session_service_from_options( - base_dir=agents_dir, - session_service_uri=session_service_uri, - session_db_kwargs=session_db_kwargs, - ) + if session_service_uri: + session_kwargs = session_db_kwargs or {} + session_service = service_registry.create_session_service( + session_service_uri, agents_dir=agents_dir, **session_kwargs + ) + if not session_service: + # Fallback to DatabaseSessionService if the service registry doesn't + # support the session service URI scheme. 
+ from ..sessions.database_session_service import DatabaseSessionService + + session_service = DatabaseSessionService( + db_url=session_service_uri, **session_kwargs + ) + else: + session_service = InMemorySessionService() # Build the Artifact service - try: - artifact_service = create_artifact_service_from_options( - base_dir=agents_dir, - artifact_service_uri=artifact_service_uri, + if artifact_service_uri: + artifact_service = service_registry.create_artifact_service( + artifact_service_uri, agents_dir=agents_dir ) - except ValueError as exc: - raise click.ClickException(str(exc)) from exc + if not artifact_service: + raise click.ClickException( + "Unsupported artifact service URI: %s" % artifact_service_uri + ) + else: + artifact_service = InMemoryArtifactService() # Build the Credential service credential_service = InMemoryCredentialService() diff --git a/src/google/adk/errors/input_validation_error.py b/src/google/adk/errors/input_validation_error.py new file mode 100644 index 0000000000..76b1625a10 --- /dev/null +++ b/src/google/adk/errors/input_validation_error.py @@ -0,0 +1,28 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + + +class InputValidationError(ValueError): + """Represents an error raised when user input fails validation.""" + + def __init__(self, message="Invalid input."): + """Initializes the InputValidationError exception. + + Args: + message (str): A message describing why the input is invalid. 
+ """ + self.message = message + super().__init__(self.message) diff --git a/tests/unittests/artifacts/test_artifact_service.py b/tests/unittests/artifacts/test_artifact_service.py index f7b457f73b..c68ad512c0 100644 --- a/tests/unittests/artifacts/test_artifact_service.py +++ b/tests/unittests/artifacts/test_artifact_service.py @@ -32,6 +32,7 @@ from google.adk.artifacts.file_artifact_service import FileArtifactService from google.adk.artifacts.gcs_artifact_service import GcsArtifactService from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.errors.input_validation_error import InputValidationError from google.genai import types import pytest @@ -732,7 +733,7 @@ async def test_file_save_artifact_rejects_out_of_scope_paths( """FileArtifactService prevents path traversal outside of its storage roots.""" artifact_service = FileArtifactService(root_dir=tmp_path / "artifacts") part = types.Part(text="content") - with pytest.raises(ValueError): + with pytest.raises(InputValidationError): await artifact_service.save_artifact( app_name="myapp", user_id="user123", @@ -757,7 +758,7 @@ async def test_file_save_artifact_rejects_absolute_path_within_scope(tmp_path): / "diagram.png" ) part = types.Part(text="content") - with pytest.raises(ValueError): + with pytest.raises(InputValidationError): await artifact_service.save_artifact( app_name="myapp", user_id="user123", diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index a8b1ef2f2f..1fe04732f5 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -30,7 +30,9 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.run_config import RunConfig from google.adk.apps.app import App +from google.adk.artifacts.base_artifact_service import ArtifactVersion from google.adk.cli.fast_api import get_fast_api_app +from google.adk.errors.input_validation_error import InputValidationError from google.adk.evaluation.eval_case import EvalCase from google.adk.evaluation.eval_case import Invocation from google.adk.evaluation.eval_result import EvalSetResult @@ -211,48 +213,135 @@ def mock_session_service(): def mock_artifact_service(): """Create a mock artifact service.""" - # Storage for artifacts - artifacts = {} + artifacts: dict[str, list[dict[str, Any]]] = {} + + def _artifact_key( + app_name: str, user_id: str, session_id: Optional[str], filename: str + ) -> str: + if session_id is None: + return f"{app_name}:{user_id}:user:{filename}" + return f"{app_name}:{user_id}:{session_id}:{filename}" + + def _canonical_uri( + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + version: int, + ) -> str: + if session_id is None: + return ( + f"artifact://apps/{app_name}/users/{user_id}/artifacts/" + f"{filename}/versions/{version}" + ) + return ( + f"artifact://apps/{app_name}/users/{user_id}/sessions/{session_id}/" + f"artifacts/{filename}/versions/{version}" + ) class MockArtifactService: + def __init__(self): + self._artifacts = artifacts + self.save_artifact_side_effect: Optional[BaseException] = None + + async def save_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + artifact: types.Part, + session_id: Optional[str] = None, + custom_metadata: Optional[dict[str, Any]] = None, + ) -> int: + if self.save_artifact_side_effect is not None: + effect = self.save_artifact_side_effect + if isinstance(effect, BaseException): + raise effect + raise TypeError( + 
"save_artifact_side_effect must be an exception instance." + ) + key = _artifact_key(app_name, user_id, session_id, filename) + entries = artifacts.setdefault(key, []) + version = len(entries) + artifact_version = ArtifactVersion( + version=version, + canonical_uri=_canonical_uri( + app_name, user_id, session_id, filename, version + ), + custom_metadata=custom_metadata or {}, + ) + if artifact.inline_data is not None: + artifact_version.mime_type = artifact.inline_data.mime_type + elif artifact.text is not None: + artifact_version.mime_type = "text/plain" + elif artifact.file_data is not None: + artifact_version.mime_type = artifact.file_data.mime_type + + entries.append({ + "version": version, + "artifact": artifact, + "metadata": artifact_version, + }) + return version + async def load_artifact( self, app_name, user_id, session_id, filename, version=None ): """Load an artifact by filename.""" - key = f"{app_name}:{user_id}:{session_id}:{filename}" + key = _artifact_key(app_name, user_id, session_id, filename) if key not in artifacts: return None if version is not None: - # Get a specific version - for v in artifacts[key]: - if v["version"] == version: - return v["artifact"] + for entry in artifacts[key]: + if entry["version"] == version: + return entry["artifact"] return None - # Get the latest version - return sorted(artifacts[key], key=lambda x: x["version"])[-1]["artifact"] + return artifacts[key][-1]["artifact"] async def list_artifact_keys(self, app_name, user_id, session_id): """List artifact names for a session.""" prefix = f"{app_name}:{user_id}:{session_id}:" return [ - k.split(":")[-1] for k in artifacts.keys() if k.startswith(prefix) + key.split(":")[-1] + for key in artifacts.keys() + if key.startswith(prefix) ] async def list_versions(self, app_name, user_id, session_id, filename): """List versions of an artifact.""" - key = f"{app_name}:{user_id}:{session_id}:{filename}" + key = _artifact_key(app_name, user_id, session_id, filename) if key not in artifacts: return [] - return [a["version"] for a in artifacts[key]] + return [entry["version"] for entry in artifacts[key]] async def delete_artifact(self, app_name, user_id, session_id, filename): """Delete an artifact.""" - key = f"{app_name}:{user_id}:{session_id}:{filename}" - if key in artifacts: - del artifacts[key] + key = _artifact_key(app_name, user_id, session_id, filename) + artifacts.pop(key, None) + + async def get_artifact_version( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + ) -> Optional[ArtifactVersion]: + key = _artifact_key(app_name, user_id, session_id, filename) + entries = artifacts.get(key) + if not entries: + return None + if version is None: + return entries[-1]["metadata"] + for entry in entries: + if entry["version"] == version: + return entry["metadata"] + return None return MockArtifactService() @@ -327,15 +416,15 @@ def test_app( with ( patch("signal.signal", return_value=None), patch( - "google.adk.cli.fast_api.create_session_service_from_options", + "google.adk.cli.fast_api.InMemorySessionService", return_value=mock_session_service, ), patch( - "google.adk.cli.fast_api.create_artifact_service_from_options", + "google.adk.cli.fast_api.InMemoryArtifactService", return_value=mock_artifact_service, ), patch( - "google.adk.cli.fast_api.create_memory_service_from_options", + "google.adk.cli.fast_api.InMemoryMemoryService", return_value=mock_memory_service, ), patch( @@ -472,15 +561,15 @@ def test_app_with_a2a( with ( 
patch("signal.signal", return_value=None), patch( - "google.adk.cli.fast_api.create_session_service_from_options", + "google.adk.cli.fast_api.InMemorySessionService", return_value=mock_session_service, ), patch( - "google.adk.cli.fast_api.create_artifact_service_from_options", + "google.adk.cli.fast_api.InMemoryArtifactService", return_value=mock_artifact_service, ), patch( - "google.adk.cli.fast_api.create_memory_service_from_options", + "google.adk.cli.fast_api.InMemoryMemoryService", return_value=mock_memory_service, ), patch( @@ -810,6 +899,87 @@ def test_list_artifact_names(test_app, create_test_session): logger.info(f"Listed {len(data)} artifacts") +def test_save_artifact(test_app, create_test_session, mock_artifact_service): + """Test saving an artifact through the FastAPI endpoint.""" + info = create_test_session + url = ( + f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/" + f"{info['session_id']}/artifacts" + ) + artifact_part = types.Part(text="hello world") + payload = { + "filename": "greeting.txt", + "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True), + } + + response = test_app.post(url, json=payload) + assert response.status_code == 200 + data = response.json() + assert data["version"] == 0 + assert data["customMetadata"] == {} + assert data["mimeType"] in (None, "text/plain") + assert data["canonicalUri"].endswith( + f"/sessions/{info['session_id']}/artifacts/" + f"{payload['filename']}/versions/0" + ) + assert isinstance(data["createTime"], float) + + key = ( + f"{info['app_name']}:{info['user_id']}:{info['session_id']}:" + f"{payload['filename']}" + ) + stored = mock_artifact_service._artifacts[key][0] + assert stored["artifact"].text == "hello world" + + +def test_save_artifact_returns_400_on_validation_error( + test_app, create_test_session, mock_artifact_service +): + """Test save artifact endpoint surfaces validation errors as HTTP 400.""" + info = create_test_session + url = ( + f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/" + f"{info['session_id']}/artifacts" + ) + artifact_part = types.Part(text="bad data") + payload = { + "filename": "invalid.txt", + "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True), + } + + mock_artifact_service.save_artifact_side_effect = InputValidationError( + "invalid artifact" + ) + + response = test_app.post(url, json=payload) + assert response.status_code == 400 + assert response.json()["detail"] == "invalid artifact" + + +def test_save_artifact_returns_500_on_unexpected_error( + test_app, create_test_session, mock_artifact_service +): + """Test save artifact endpoint surfaces unexpected errors as HTTP 500.""" + info = create_test_session + url = ( + f"/apps/{info['app_name']}/users/{info['user_id']}/sessions/" + f"{info['session_id']}/artifacts" + ) + artifact_part = types.Part(text="bad data") + payload = { + "filename": "invalid.txt", + "artifact": artifact_part.model_dump(by_alias=True, exclude_none=True), + } + + mock_artifact_service.save_artifact_side_effect = RuntimeError( + "unexpected failure" + ) + + response = test_app.post(url, json=payload) + assert response.status_code == 500 + assert response.json()["detail"] == "unexpected failure" + + def test_create_eval_set(test_app, test_session_info): """Test creating an eval set.""" url = f"/apps/{test_session_info['app_name']}/eval_sets/test_eval_set_id" From 98d82935e683115a6f4a0fdb307dadcbeddc0619 Mon Sep 17 00:00:00 2001 From: George Weale Date: Mon, 1 Dec 2025 16:30:49 -0800 Subject: [PATCH 51/63] fix: allow 
LlmAgent model to be provided via CodeConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit LlmAgentConfig.model now accepts either a plain model string or a CodeConfig. This lets YAML configs pass a LiteLLM instance with managed API settings (e.g., api_base and fallbacks) so agents can hit KimiK2’s managed endpoint instead of only the default modelID. Close #3579 Co-authored-by: George Weale PiperOrigin-RevId: 838978654 --- src/google/adk/agents/llm_agent.py | 9 +++- src/google/adk/agents/llm_agent_config.py | 43 +++++++++++++++++- tests/unittests/agents/test_agent_config.py | 48 +++++++++++++++++++++ 3 files changed, 96 insertions(+), 4 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 71a074881c..005d073cc7 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -469,7 +469,10 @@ async def _run_async_impl( if ctx.is_resumable: events = ctx._get_events(current_invocation=True, current_branch=True) - if any(ctx.should_pause_invocation(e) for e in events[-2:]): + if events and ( + ctx.should_pause_invocation(events[-1]) + or ctx.should_pause_invocation(events[-2]) + ): return # Only yield an end state if the last event is no longer a long running # tool call. @@ -907,7 +910,9 @@ def _parse_config( from .config_agent_utils import resolve_callbacks from .config_agent_utils import resolve_code_reference - if config.model: + if config.model_code: + kwargs['model'] = resolve_code_reference(config.model_code) + elif config.model: kwargs['model'] = config.model if config.instruction: kwargs['instruction'] = config.instruction diff --git a/src/google/adk/agents/llm_agent_config.py b/src/google/adk/agents/llm_agent_config.py index 7d2493597e..59c6d58869 100644 --- a/src/google/adk/agents/llm_agent_config.py +++ b/src/google/adk/agents/llm_agent_config.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +from typing import Any from typing import List from typing import Literal from typing import Optional @@ -22,6 +23,7 @@ from google.genai import types from pydantic import ConfigDict from pydantic import Field +from pydantic import model_validator from ..tools.tool_configs import ToolConfig from .base_agent_config import BaseAgentConfig @@ -52,11 +54,48 @@ class LlmAgentConfig(BaseAgentConfig): model: Optional[str] = Field( default=None, description=( - 'Optional. LlmAgent.model. If not set, the model will be inherited' - ' from the ancestor.' + 'Optional. LlmAgent.model. Provide a model name string (e.g.' + ' "gemini-2.0-flash"). If not set, the model will be inherited from' + ' the ancestor. To construct a model instance from code, use' + ' model_code.' ), ) + model_code: Optional[CodeConfig] = Field( + default=None, + description=( + 'Optional. A CodeConfig that instantiates a BaseLlm implementation' + ' such as LiteLlm with custom arguments (API base, fallbacks,' + ' etc.). Cannot be set together with `model`.' + ), + ) + + @model_validator(mode='before') + @classmethod + def _normalize_model_code(cls, data: Any) -> dict[str, Any] | Any: + if not isinstance(data, dict): + return data + + model_value = data.get('model') + model_code = data.get('model_code') + if isinstance(model_value, dict) and model_code is None: + logger.warning( + 'Detected legacy `model` mapping. Use `model_code` to provide a' + ' CodeConfig for custom model construction.' 
+ ) + data = dict(data) + data['model_code'] = model_value + data['model'] = None + + return data + + @model_validator(mode='after') + def _validate_model_sources(self) -> LlmAgentConfig: + if self.model and self.model_code: + raise ValueError('Only one of `model` or `model_code` should be set.') + + return self + instruction: str = Field( description=( 'Required. LlmAgent.instruction. Dynamic instructions with' diff --git a/tests/unittests/agents/test_agent_config.py b/tests/unittests/agents/test_agent_config.py index 3d8e9209f9..86fda7fc9b 100644 --- a/tests/unittests/agents/test_agent_config.py +++ b/tests/unittests/agents/test_agent_config.py @@ -29,6 +29,7 @@ from google.adk.agents.loop_agent import LoopAgent from google.adk.agents.parallel_agent import ParallelAgent from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.models.lite_llm import LiteLlm import pytest import yaml @@ -259,6 +260,53 @@ def test_agent_config_discriminator_llm_agent_with_sub_agents( assert config.root.agent_class == agent_class_value +def test_agent_config_litellm_model_with_custom_args(tmp_path: Path): + yaml_content = """\ +name: managed_api_agent +description: Agent using LiteLLM managed endpoint +instruction: Respond concisely. +model_code: + name: google.adk.models.lite_llm.LiteLlm + args: + - name: model + value: kimi/k2 + - name: api_base + value: https://proxy.litellm.ai/v1 +""" + config_file = tmp_path / "litellm_agent.yaml" + config_file.write_text(yaml_content) + + agent = config_agent_utils.from_config(str(config_file)) + + assert isinstance(agent, LlmAgent) + assert isinstance(agent.model, LiteLlm) + assert agent.model.model == "kimi/k2" + assert agent.model._additional_args.get("api_base") == ( + "https://proxy.litellm.ai/v1" + ) + + +def test_agent_config_legacy_model_mapping_still_supported(tmp_path: Path): + yaml_content = """\ +name: managed_api_agent +description: Agent using LiteLLM managed endpoint +instruction: Respond concisely. +model: + name: google.adk.models.lite_llm.LiteLlm + args: + - name: model + value: kimi/k2 +""" + config_file = tmp_path / "legacy_litellm_agent.yaml" + config_file.write_text(yaml_content) + + agent = config_agent_utils.from_config(str(config_file)) + + assert isinstance(agent, LlmAgent) + assert isinstance(agent.model, LiteLlm) + assert agent.model.model == "kimi/k2" + + def test_agent_config_discriminator_custom_agent(): class MyCustomAgentConfig(BaseAgentConfig): agent_class: Literal["mylib.agents.MyCustomAgent"] = ( From 8da61be45aa01f5f319e3d753cba25adcf27a5a0 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Mon, 1 Dec 2025 18:08:46 -0800 Subject: [PATCH 52/63] fix: Flush pending transcriptions on turn/generation complete or interrupt for Gemini API The Gemini API may not always send an explicit transcription finished signal. This change ensures that any buffered input or output transcription text is yielded as a finished transcription when a turn is completed, generation is complete, or the session is interrupted. Also, refined the check for `event.partial` in runners.py to be more explicit. 
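In pseudocode, the flush behavior described above amounts to the following check. This is an illustrative sketch only, not the patched implementation; `server_content` stands in for the google.genai live server message, and `is_gemini_api` for the backend check.

```py
def should_flush_pending_transcriptions(server_content, is_gemini_api: bool) -> bool:
    # Only the Gemini API backend may omit an explicit "transcription finished"
    # signal, so buffered transcription text is flushed on any of these
    # end-of-turn style events instead.
    return is_gemini_api and bool(
        server_content.interrupted
        or server_content.turn_complete
        or server_content.generation_complete
    )
```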
Co-authored-by: Hangfei Lin PiperOrigin-RevId: 839008606 --- .../live_bidi_streaming_single_agent/agent.py | 4 +- .../adk/models/gemini_llm_connection.py | 37 ++- src/google/adk/models/google_llm.py | 2 +- src/google/adk/runners.py | 83 +++++- .../models/test_gemini_llm_connection.py | 240 +++++++++++++++++- 5 files changed, 359 insertions(+), 7 deletions(-) diff --git a/contributing/samples/live_bidi_streaming_single_agent/agent.py b/contributing/samples/live_bidi_streaming_single_agent/agent.py index c295adc136..9246fca9d5 100755 --- a/contributing/samples/live_bidi_streaming_single_agent/agent.py +++ b/contributing/samples/live_bidi_streaming_single_agent/agent.py @@ -65,8 +65,8 @@ async def check_prime(nums: list[int]) -> str: root_agent = Agent( - # model='gemini-live-2.5-flash-preview-native-audio-09-2025', # vertex - model='gemini-2.5-flash-native-audio-preview-09-2025', # for AI studio + model='gemini-live-2.5-flash-preview-native-audio-09-2025', # vertex + # model='gemini-2.5-flash-native-audio-preview-09-2025', # for AI studio # key name='roll_dice_agent', description=( diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index 15e6ed9599..55d4b62e96 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -21,6 +21,7 @@ from google.genai import types from ..utils.context_utils import Aclosing +from ..utils.variant_utils import GoogleLLMVariant from .base_llm_connection import BaseLlmConnection from .llm_response import LlmResponse @@ -36,10 +37,15 @@ class GeminiLlmConnection(BaseLlmConnection): """The Gemini model connection.""" - def __init__(self, gemini_session: live.AsyncSession): + def __init__( + self, + gemini_session: live.AsyncSession, + api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI, + ): self._gemini_session = gemini_session self._input_transcription_text: str = '' self._output_transcription_text: str = '' + self._api_backend = api_backend async def send_history(self, history: list[types.Content]): """Sends the conversation history to the gemini model. @@ -171,6 +177,9 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: yield self.__build_full_text_response(text) text = '' yield llm_response + # Note: in some cases, tool_call may arrive before + # generation_complete, causing transcription to appear after + # tool_call in the session log. if message.server_content.input_transcription: if message.server_content.input_transcription.text: self._input_transcription_text += ( @@ -215,6 +224,32 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: partial=False, ) self._output_transcription_text = '' + # The Gemini API might not send a transcription finished signal. + # Instead, we rely on generation_complete, turn_complete or + # interrupted signals to flush any pending transcriptions. 
+ if self._api_backend == GoogleLLMVariant.GEMINI_API and ( + message.server_content.interrupted + or message.server_content.turn_complete + or message.server_content.generation_complete + ): + if self._input_transcription_text: + yield LlmResponse( + input_transcription=types.Transcription( + text=self._input_transcription_text, + finished=True, + ), + partial=False, + ) + self._input_transcription_text = '' + if self._output_transcription_text: + yield LlmResponse( + output_transcription=types.Transcription( + text=self._output_transcription_text, + finished=True, + ), + partial=False, + ) + self._output_transcription_text = '' if message.server_content.turn_complete: if text: yield self.__build_full_text_response(text) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 93d802ecdc..6b21cf62c7 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -342,7 +342,7 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection: async with self._live_api_client.aio.live.connect( model=llm_request.model, config=llm_request.live_connect_config ) as live_session: - yield GeminiLlmConnection(live_session) + yield GeminiLlmConnection(live_session, api_backend=self._api_backend) async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None: """Adapt the google computer use predefined functions to the adk computer use toolset.""" diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index db9828f66e..4cf5a29546 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -67,6 +67,23 @@ logger = logging.getLogger('google_adk.' + __name__) +def _is_tool_call_or_response(event: Event) -> bool: + return bool(event.get_function_calls() or event.get_function_responses()) + + +def _is_transcription(event: Event) -> bool: + return ( + event.input_transcription is not None + or event.output_transcription is not None + ) + + +def _has_non_empty_transcription_text(transcription) -> bool: + return bool( + transcription and transcription.text and transcription.text.strip() + ) + + class Runner: """The Runner class is used to run agents. @@ -626,6 +643,7 @@ async def _exec_with_plugin( invocation_context: The invocation context session: The current session execute_fn: A callable that returns an AsyncGenerator of Events + is_live_call: Whether this is a live call Yields: Events from the execution, including any generated by plugins @@ -651,13 +669,74 @@ async def _exec_with_plugin( yield early_exit_event else: # Step 2: Otherwise continue with normal execution + # Note for live/bidi: + # the transcription may arrive later then the action(function call + # event and thus function response event). In this case, the order of + # transcription and function call event will be wrong if we just + # append as it arrives. To address this, we should check if there is + # transcription going on. If there is transcription going on, we + # should hold on appending the function call event until the + # transcription is finished. The transcription in progress can be + # identified by checking if the transcription event is partial. When + # the next transcription event is not partial, it means the previous + # transcription is finished. Then if there is any buffered function + # call event, we should append them after this finished(non-parital) + # transcription event. 
+ buffered_events: list[Event] = [] + is_transcribing: bool = False + async with Aclosing(execute_fn(invocation_context)) as agen: async for event in agen: - if not event.partial: - if self._should_append_event(event, is_live_call): + if is_live_call: + if event.partial and _is_transcription(event): + is_transcribing = True + if is_transcribing and _is_tool_call_or_response(event): + # only buffer function call and function response event which is + # non-partial + buffered_events.append(event) + continue + # Note for live/bidi: for audio response, it's considered as + # non-paritla event(event.partial=None) + # event.partial=False and event.partial=None are considered as + # non-partial event; event.partial=True is considered as partial + # event. + if event.partial is not True: + if _is_transcription(event) and ( + _has_non_empty_transcription_text(event.input_transcription) + or _has_non_empty_transcription_text( + event.output_transcription + ) + ): + # transcription end signal, append buffered events + is_transcribing = False + logger.debug( + 'Appending transcription finished event: %s', event + ) + if self._should_append_event(event, is_live_call): + await self.session_service.append_event( + session=session, event=event + ) + + for buffered_event in buffered_events: + logger.debug('Appending buffered event: %s', buffered_event) + await self.session_service.append_event( + session=session, event=buffered_event + ) + buffered_events = [] + else: + # non-transcription event or empty transcription event, for + # example, event that stores blob reference, should be appended. + if self._should_append_event(event, is_live_call): + logger.debug('Appending non-buffered event: %s', event) + await self.session_service.append_event( + session=session, event=event + ) + else: + if event.partial is not True: await self.session_service.append_event( session=session, event=event ) + # Step 3: Run the on_event callbacks to optionally modify the event. 
modified_event = await plugin_manager.run_on_event_callback( invocation_context=invocation_context, event=event diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 6d3e685748..190007603c 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -15,6 +15,7 @@ from unittest import mock from google.adk.models.gemini_llm_connection import GeminiLlmConnection +from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types import pytest @@ -28,7 +29,17 @@ def mock_gemini_session(): @pytest.fixture def gemini_connection(mock_gemini_session): """GeminiLlmConnection instance with mocked session.""" - return GeminiLlmConnection(mock_gemini_session) + return GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI + ) + + +@pytest.fixture +def gemini_api_connection(mock_gemini_session): + """GeminiLlmConnection instance with mocked session for Gemini API.""" + return GeminiLlmConnection( + mock_gemini_session, api_backend=GoogleLLMVariant.GEMINI_API + ) @pytest.fixture @@ -225,6 +236,227 @@ async def mock_receive_generator(): assert content_response.content == mock_content +@pytest.mark.asyncio +async def test_receive_transcript_finished_on_interrupt( + gemini_api_connection, + mock_gemini_session, +): + """Test receive finishes transcription on interrupt signal.""" + + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = None + message1.server_content.interrupted = False + message1.server_content.input_transcription = types.Transcription( + text='Hello', finished=False + ) + message1.server_content.output_transcription = None + message1.server_content.turn_complete = False + message1.server_content.generation_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = None + message2.server_content.interrupted = False + message2.server_content.input_transcription = None + message2.server_content.output_transcription = types.Transcription( + text='How can', finished=False + ) + message2.server_content.turn_complete = False + message2.server_content.generation_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None + message3.server_content.interrupted = True + message3.server_content.input_transcription = None + message3.server_content.output_transcription = None + message3.server_content.turn_complete = False + message3.server_content.generation_complete = False + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_api_connection.receive()] + + assert len(responses) == 5 + assert responses[4].interrupted is True + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is False + assert responses[0].partial is True + assert responses[1].output_transcription.text == 
'How can' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + assert responses[2].input_transcription.text == 'Hello' + assert responses[2].input_transcription.finished is True + assert responses[2].partial is False + assert responses[3].output_transcription.text == 'How can' + assert responses[3].output_transcription.finished is True + assert responses[3].partial is False + + +@pytest.mark.asyncio +async def test_receive_transcript_finished_on_generation_complete( + gemini_api_connection, + mock_gemini_session, +): + """Test receive finishes transcription on generation_complete signal.""" + + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = None + message1.server_content.interrupted = False + message1.server_content.input_transcription = types.Transcription( + text='Hello', finished=False + ) + message1.server_content.output_transcription = None + message1.server_content.turn_complete = False + message1.server_content.generation_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = None + message2.server_content.interrupted = False + message2.server_content.input_transcription = None + message2.server_content.output_transcription = types.Transcription( + text='How can', finished=False + ) + message2.server_content.turn_complete = False + message2.server_content.generation_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None + message3.server_content.interrupted = False + message3.server_content.input_transcription = None + message3.server_content.output_transcription = None + message3.server_content.turn_complete = False + message3.server_content.generation_complete = True + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_api_connection.receive()] + + assert len(responses) == 4 + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is False + assert responses[0].partial is True + assert responses[1].output_transcription.text == 'How can' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + assert responses[2].input_transcription.text == 'Hello' + assert responses[2].input_transcription.finished is True + assert responses[2].partial is False + assert responses[3].output_transcription.text == 'How can' + assert responses[3].output_transcription.finished is True + assert responses[3].partial is False + + +@pytest.mark.asyncio +async def test_receive_transcript_finished_on_turn_complete( + gemini_api_connection, + mock_gemini_session, +): + """Test receive finishes transcription on interrupt or complete signals.""" + + message1 = mock.Mock() + message1.usage_metadata = None + message1.server_content = mock.Mock() + message1.server_content.model_turn = None + message1.server_content.interrupted = False + 
message1.server_content.input_transcription = types.Transcription( + text='Hello', finished=False + ) + message1.server_content.output_transcription = None + message1.server_content.turn_complete = False + message1.server_content.generation_complete = False + message1.tool_call = None + message1.session_resumption_update = None + + message2 = mock.Mock() + message2.usage_metadata = None + message2.server_content = mock.Mock() + message2.server_content.model_turn = None + message2.server_content.interrupted = False + message2.server_content.input_transcription = None + message2.server_content.output_transcription = types.Transcription( + text='How can', finished=False + ) + message2.server_content.turn_complete = False + message2.server_content.generation_complete = False + message2.tool_call = None + message2.session_resumption_update = None + + message3 = mock.Mock() + message3.usage_metadata = None + message3.server_content = mock.Mock() + message3.server_content.model_turn = None + message3.server_content.interrupted = False + message3.server_content.input_transcription = None + message3.server_content.output_transcription = None + message3.server_content.turn_complete = True + message3.server_content.generation_complete = False + message3.tool_call = None + message3.session_resumption_update = None + + async def mock_receive_generator(): + yield message1 + yield message2 + yield message3 + + receive_mock = mock.Mock(return_value=mock_receive_generator()) + mock_gemini_session.receive = receive_mock + + responses = [resp async for resp in gemini_api_connection.receive()] + + assert len(responses) == 5 + assert responses[4].turn_complete is True + + assert responses[0].input_transcription.text == 'Hello' + assert responses[0].input_transcription.finished is False + assert responses[0].partial is True + assert responses[1].output_transcription.text == 'How can' + assert responses[1].output_transcription.finished is False + assert responses[1].partial is True + assert responses[2].input_transcription.text == 'Hello' + assert responses[2].input_transcription.finished is True + assert responses[2].partial is False + assert responses[3].output_transcription.text == 'How can' + assert responses[3].output_transcription.finished is True + assert responses[3].partial is False + + @pytest.mark.asyncio async def test_receive_handles_input_transcription_fragments( gemini_connection, mock_gemini_session @@ -240,6 +472,7 @@ async def test_receive_handles_input_transcription_fragments( ) message1.server_content.output_transcription = None message1.server_content.turn_complete = False + message1.server_content.generation_complete = False message1.tool_call = None message1.session_resumption_update = None @@ -253,6 +486,7 @@ async def test_receive_handles_input_transcription_fragments( ) message2.server_content.output_transcription = None message2.server_content.turn_complete = False + message2.server_content.generation_complete = False message2.tool_call = None message2.session_resumption_update = None @@ -266,6 +500,7 @@ async def test_receive_handles_input_transcription_fragments( ) message3.server_content.output_transcription = None message3.server_content.turn_complete = False + message3.server_content.generation_complete = False message3.tool_call = None message3.session_resumption_update = None @@ -306,6 +541,7 @@ async def test_receive_handles_output_transcription_fragments( text='How can', finished=False ) message1.server_content.turn_complete = False + 
message1.server_content.generation_complete = False message1.tool_call = None message1.session_resumption_update = None @@ -319,6 +555,7 @@ async def test_receive_handles_output_transcription_fragments( text=' I help?', finished=False ) message2.server_content.turn_complete = False + message2.server_content.generation_complete = False message2.tool_call = None message2.session_resumption_update = None @@ -332,6 +569,7 @@ async def test_receive_handles_output_transcription_fragments( text=None, finished=True ) message3.server_content.turn_complete = False + message3.server_content.generation_complete = False message3.tool_call = None message3.session_resumption_update = None From 090711934f9cd4a3327a6df626167130a217e04a Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Tue, 2 Dec 2025 11:18:23 -0800 Subject: [PATCH 53/63] feat: add Spanner vector_store_similarity_search tool The vector_store_similarity_search tool performs similarity search against data in a Spanner vector store table, using the provided Spanner tool settings for configuration. PiperOrigin-RevId: 839352057 --- .../samples/spanner_rag_agent/README.md | 241 ++++++++++--- .../samples/spanner_rag_agent/agent.py | 240 +++---------- src/google/adk/tools/spanner/search_tool.py | 323 ++++++++++++++---- src/google/adk/tools/spanner/settings.py | 102 +++++- .../adk/tools/spanner/spanner_toolset.py | 12 + src/google/adk/tools/spanner/utils.py | 24 ++ .../tools/spanner/test_search_tool.py | 281 ++++++++++++--- .../spanner/test_spanner_tool_settings.py | 45 +++ .../tools/spanner/test_spanner_toolset.py | 48 +++ 9 files changed, 957 insertions(+), 359 deletions(-) diff --git a/contributing/samples/spanner_rag_agent/README.md b/contributing/samples/spanner_rag_agent/README.md index 8148983736..99b60794fe 100644 --- a/contributing/samples/spanner_rag_agent/README.md +++ b/contributing/samples/spanner_rag_agent/README.md @@ -57,9 +57,9 @@ model endpoint. CREATE MODEL EmbeddingsModel INPUT( content STRING(MAX), ) OUTPUT( -embeddings STRUCT, values ARRAY>, +embeddings STRUCT>, ) REMOTE OPTIONS ( -endpoint = '//aiplatform.googleapis.com/projects//locations/us-central1/publishers/google/models/text-embedding-004' +endpoint = '//aiplatform.googleapis.com/projects//locations//publishers/google/models/text-embedding-005' ); ``` @@ -187,40 +187,203 @@ type. ## Which tool to use and When? -There are a few options to perform similarity search (see the `agent.py` for -implementation details): - -1. Wraps the built-in `similarity_search` in the Spanner Toolset. - - - This provides an easy and controlled way to perform similarity search. - You can specify different configurations related to vector search based - on your need without having to figure out all the details for a vector - search query. - -2. Wraps the built-in `execute_sql` in the Spanner Toolset. - - - `execute_sql` is a lower-level tool that you can have more control over - with. With the flexibility, you can specify a complicated (parameterized) - SQL query for your need, and let the `LlmAgent` pass the parameters. - -3. Use the Spanner Toolset (and all the tools that come with it) directly. - - - The most flexible and generic way. Instead of fixing configurations via - code, you can also specify the configurations via `instruction` to - the `LlmAgent` and let LLM to decide which tool to use and what parameters - to pass to different tools. It might even combine different tools together! 
- Note that in this usage, SQL generation is powered by the LlmAgent, which - can be more suitable for data analysis and assistant scenarios. - - To restrict the ability of an `LlmAgent`, `SpannerToolSet` also supports - `tool_filter` to explicitly specify allowed tools. As an example, the - following code specifies that only `execute_sql` and `get_table_schema` - are allowed: - - ```py - toolset = SpannerToolset( - credentials_config=credentials_config, - tool_filter=["execute_sql", "get_table_schema"], - spanner_tool_settings=SpannerToolSettings(), - ) - ``` - +There are a few options to perform similarity search: + +1. Use the built-in `vector_store_similarity_search` in the Spanner Toolset with explicit `SpannerVectorStoreSettings` configuration. + + - This provides an easy way to perform similarity search. You can specify + different configurations related to vector search based on your Spanner + database vector store table setup. + + Example pseudocode (see the `agent.py` for details): + + ```py + from google.adk.agents.llm_agent import LlmAgent + from google.adk.tools.spanner.settings import Capabilities + from google.adk.tools.spanner.settings import SpannerToolSettings + from google.adk.tools.spanner.settings import SpannerVectorStoreSettings + from google.adk.tools.spanner.spanner_toolset import SpannerToolset + + # credentials_config = SpannerCredentialsConfig(...) + + # Define Spanner tool config with the vector store settings. + vector_store_settings = SpannerVectorStoreSettings( + project_id="", + instance_id="", + database_id="", + table_name="products", + content_column="productDescription", + embedding_column="productDescriptionEmbedding", + vector_length=768, + vertex_ai_embedding_model_name="text-embedding-005", + selected_columns=[ + "productId", + "productName", + "productDescription", + ], + nearest_neighbors_algorithm="EXACT_NEAREST_NEIGHBORS", + top_k=3, + distance_type="COSINE", + additional_filter="inventoryCount > 0", + ) + + tool_settings = SpannerToolSettings( + capabilities=[Capabilities.DATA_READ], + vector_store_settings=vector_store_settings, + ) + + # Get the Spanner toolset with the Spanner tool settings and credentials config. + spanner_toolset = SpannerToolset( + credentials_config=credentials_config, + spanner_tool_settings=tool_settings, + # Use `vector_store_similarity_search` only + tool_filter=["vector_store_similarity_search"], + ) + + root_agent = LlmAgent( + model="gemini-2.5-flash", + name="spanner_knowledge_base_agent", + description=( + "Agent to answer questions about product-specific recommendations." + ), + instruction=""" + You are a helpful assistant that answers user questions about product-specific recommendations. + 1. Always use the `vector_store_similarity_search` tool to find relevant information. + 2. If no relevant information is found, say you don't know. + 3. Present all the relevant information naturally and well formatted in your response. + """, + tools=[spanner_toolset], + ) + ``` + +2. Use the built-in `similarity_search` in the Spanner Toolset. + + - `similarity_search` is a lower-level tool, which provide the most flexible + and generic way. Specify all the necessary tool's parameters is required + when interacting with `LlmAgent` before performing the tool call. This is + more suitable for data analysis, ad-hoc query and assistant scenarios. 
+ + Example pseudocode: + + ```py + from google.adk.agents.llm_agent import LlmAgent + from google.adk.tools.spanner.settings import Capabilities + from google.adk.tools.spanner.settings import SpannerToolSettings + from google.adk.tools.spanner.spanner_toolset import SpannerToolset + + # credentials_config = SpannerCredentialsConfig(...) + + tool_settings = SpannerToolSettings( + capabilities=[Capabilities.DATA_READ], + ) + + spanner_toolset = SpannerToolset( + credentials_config=credentials_config, + spanner_tool_settings=tool_settings, + # Use `similarity_search` only + tool_filter=["similarity_search"], + ) + + root_agent = LlmAgent( + model="gemini-2.5-flash", + name="spanner_knowledge_base_agent", + description=( + "Agent to answer questions by retrieving relevant information " + "from the Spanner database." + ), + instruction=""" + You are a helpful assistant that answers user questions to find the most relavant information from a Spanner database. + 1. Always use the `similarity_search` tool to find relevant information. + 2. If no relevant information is found, say you don't know. + 3. Present all the relevant information naturally and well formatted in your response. + """, + tools=[spanner_toolset], + ) + ``` + +3. Wraps the built-in `similarity_search` in the Spanner Toolset. + + - This provides a more controlled way to perform similarity search via code. + You can extend the tool as a wrapped function tool to have customized logic. + + Example pseudocode: + + ```py + from google.adk.agents.llm_agent import LlmAgent + + from google.adk.tools.google_tool import GoogleTool + from google.adk.tools.spanner import search_tool + import google.auth + from google.auth.credentials import Credentials + + # credentials_config = SpannerCredentialsConfig(...) + + # Create a wrapped function tool for the agent on top of the built-in + # similarity_search tool in the Spanner toolset. + # This customized tool is used to perform a Spanner KNN vector search on a + # embedded knowledge base stored in a Spanner database table. + def wrapped_spanner_similarity_search( + search_query: str, + credentials: Credentials, + ) -> str: + """Perform a similarity search on the product catalog. + + Args: + search_query: The search query to find relevant content. + + Returns: + Relevant product catalog content with sources + """ + + # ... Customized logic ... + + # Instead of fixing all parameters, you can also expose some of them for + # the LLM to decide. + return search_tool.similarity_search( + project_id="", + instance_id="", + database_id="", + table_name="products", + query=search_query, + embedding_column_to_search="productDescriptionEmbedding", + columns= [ + "productId", + "productName", + "productDescription", + ] + embedding_options={ + "vertex_ai_embedding_model_name": "text-embedding-005", + }, + credentials=credentials, + additional_filter="inventoryCount > 0", + search_options={ + "top_k": 3, + "distance_type": "EUCLIDEAN", + }, + ) + + # ... + + root_agent = LlmAgent( + model="gemini-2.5-flash", + name="spanner_knowledge_base_agent", + description=( + "Agent to answer questions about product-specific recommendations." + ), + instruction=""" + You are a helpful assistant that answers user questions about product-specific recommendations. + 1. Always use the `wrapped_spanner_similarity_search` tool to find relevant information. + 2. If no relevant information is found, say you don't know. + 3. Present all the relevant information naturally and well formatted in your response. 
+ """, + tools=[ + # Add customized Spanner tool based on the built-in similarity_search + # in the Spanner toolset. + GoogleTool( + func=wrapped_spanner_similarity_search, + credentials_config=credentials_config, + tool_settings=tool_settings, + ), + ], + ) + ``` diff --git a/contributing/samples/spanner_rag_agent/agent.py b/contributing/samples/spanner_rag_agent/agent.py index 58ec95294e..1460242184 100644 --- a/contributing/samples/spanner_rag_agent/agent.py +++ b/contributing/samples/spanner_rag_agent/agent.py @@ -13,23 +13,15 @@ # limitations under the License. import os -from typing import Any -from typing import Dict -from typing import Optional from google.adk.agents.llm_agent import LlmAgent from google.adk.auth.auth_credential import AuthCredentialTypes -from google.adk.tools.base_tool import BaseTool -from google.adk.tools.google_tool import GoogleTool -from google.adk.tools.spanner import query_tool -from google.adk.tools.spanner import search_tool from google.adk.tools.spanner.settings import Capabilities from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig -from google.adk.tools.tool_context import ToolContext +from google.adk.tools.spanner.spanner_toolset import SpannerToolset import google.auth -from google.auth.credentials import Credentials -from pydantic import BaseModel # Define an appropriate credential type # Set to None to use the application default credentials (ADC) for a quick @@ -37,9 +29,6 @@ CREDENTIALS_TYPE = None -# Define Spanner tool config with read capability set to allowed. -tool_settings = SpannerToolSettings(capabilities=[Capabilities.DATA_READ]) - if CREDENTIALS_TYPE == AuthCredentialTypes.OAUTH2: # Initialize the tools to do interactive OAuth # The environment variables OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET @@ -67,172 +56,46 @@ credentials=application_default_credentials ) +# Follow the instructions in README.md to set up the example Spanner database. +# Replace the following settings with your specific Spanner database. + +# Define Spanner vector store settings. +vector_store_settings = SpannerVectorStoreSettings( + project_id="", + instance_id="", + database_id="", + table_name="products", + content_column="productDescription", + embedding_column="productDescriptionEmbedding", + vector_length=768, + vertex_ai_embedding_model_name="text-embedding-005", + selected_columns=[ + "productId", + "productName", + "productDescription", + ], + nearest_neighbors_algorithm="EXACT_NEAREST_NEIGHBORS", + top_k=3, + distance_type="COSINE", + additional_filter="inventoryCount > 0", +) -### Section 1: Extending the built-in Spanner Toolset for Custom Use Cases ### -# This example illustrates how to extend the built-in Spanner toolset to create -# a customized Spanner tool. This method is advantageous when you need to deal -# with a specific use case: -# -# 1. Streamline the end user experience by pre-configuring the tool with fixed -# parameters (such as a specific database, instance, or project) and a -# dedicated SQL query, making it perfect for a single, focused use case -# like vector search on a specific table. -# 2. Enhance functionality by adding custom logic to manage tool inputs, -# execution, and result processing, providing greater control over the -# tool's behavior. 
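(Editorial note on the option 3 sketch above: it references `credentials_config` and `tool_settings` without defining them. A minimal assumed setup, consistent with the credentials and settings APIs used elsewhere in this sample, could look like the following — an illustrative sketch, not part of this patch; adjust the capability list and credential type to your environment.)

```py
from google.adk.tools.spanner.settings import Capabilities
from google.adk.tools.spanner.settings import SpannerToolSettings
from google.adk.tools.spanner.spanner_credentials import SpannerCredentialsConfig
import google.auth

# Assumed: application default credentials (ADC) and read-only data access.
application_default_credentials, _ = google.auth.default()
credentials_config = SpannerCredentialsConfig(
    credentials=application_default_credentials
)
tool_settings = SpannerToolSettings(capabilities=[Capabilities.DATA_READ])
```

With these two names defined, the `GoogleTool(...)` wrapper shown in the option 3 example resolves both of them.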
-class SpannerRagSetting(BaseModel): - """Customized Spanner RAG settings for an example use case.""" - - # Replace the following settings for your Spanner database used in the sample. - project_id: str = "" - instance_id: str = "" - database_id: str = "" - - # Follow the instructions in README.md, the table name is "products" and the - # Spanner embedding model name is "EmbeddingsModel" in this sample. - table_name: str = "products" - # Learn more about Spanner Vertex AI integration for embedding and Spanner - # vector search. - # https://cloud.google.com/spanner/docs/ml-tutorial-embeddings - # https://cloud.google.com/spanner/docs/vector-search/overview - embedding_model_name: str = "EmbeddingsModel" - - selected_columns: list[str] = [ - "productId", - "productName", - "productDescription", - ] - embedding_column_name: str = "productDescriptionEmbedding" - - additional_filter_expression: str = "inventoryCount > 0" - vector_distance_function: str = "EUCLIDEAN_DISTANCE" - top_k: int = 3 - - -RAG_SETTINGS = SpannerRagSetting() - - -### (Option 1) Use the built-in similarity_search tool ### -# Create a wrapped function tool for the agent on top of the built-in -# similarity_search tool in the Spanner toolset. -# This customized tool is used to perform a Spanner KNN vector search on a -# embedded knowledge base stored in a Spanner database table. -def wrapped_spanner_similarity_search( - search_query: str, - credentials: Credentials, # GoogleTool handles `credentials` automatically - settings: SpannerToolSettings, # GoogleTool handles `settings` automatically - tool_context: ToolContext, # GoogleTool handles `tool_context` automatically -) -> str: - """Perform a similarity search on the product catalog. - - Args: - search_query: The search query to find relevant content. - - Returns: - Relevant product catalog content with sources - """ - columns = RAG_SETTINGS.selected_columns.copy() - - # Instead of fixing all parameters, you can also expose some of them for - # the LLM to decide. - return search_tool.similarity_search( - RAG_SETTINGS.project_id, - RAG_SETTINGS.instance_id, - RAG_SETTINGS.database_id, - RAG_SETTINGS.table_name, - search_query, - RAG_SETTINGS.embedding_column_name, - columns, - { - "spanner_embedding_model_name": RAG_SETTINGS.embedding_model_name, - }, - credentials, - settings, - tool_context, - RAG_SETTINGS.additional_filter_expression, - { - "top_k": RAG_SETTINGS.top_k, - "distance_type": RAG_SETTINGS.vector_distance_function, - }, - ) - - -### (Option 2) Use the built-in execute_sql tool ### -# Create a wrapped function tool for the agent on top of the built-in -# execute_sql tool in the Spanner toolset. -# This customized tool is used to perform a Spanner KNN vector search on a -# embedded knowledge base stored in a Spanner database table. -# -# Compared with similarity_search, using execute_sql (a lower level tool) means -# that you have more control, but you also need to do more work (e.g. to write -# the SQL query from scratch). Consider using this option if your scenario is -# more complicated than a plain similarity search. -def wrapped_spanner_execute_sql_tool( - search_query: str, - credentials: Credentials, # GoogleTool handles `credentials` automatically - settings: SpannerToolSettings, # GoogleTool handles `settings` automatically - tool_context: ToolContext, # GoogleTool handles `tool_context` automatically -) -> str: - """Perform a similarity search on the product catalog. - - Args: - search_query: The search query to find relevant content. 
- - Returns: - Relevant product catalog content with sources - """ - - embedding_query = f"""SELECT embeddings.values - FROM ML.PREDICT( - MODEL {RAG_SETTINGS.embedding_model_name}, - (SELECT "{search_query}" as content) - ) - """ - - distance_alias = "distance" - columns = [f"{column}" for column in RAG_SETTINGS.selected_columns] - columns += [f"""{RAG_SETTINGS.vector_distance_function}( - {RAG_SETTINGS.embedding_column_name}, - ({embedding_query})) AS {distance_alias} - """] - columns = ", ".join(columns) - - knn_query = f""" - SELECT {columns} - FROM {RAG_SETTINGS.table_name} - WHERE {RAG_SETTINGS.additional_filter_expression} - ORDER BY {distance_alias} - LIMIT {RAG_SETTINGS.top_k} - """ - - # Customized tool based on the built-in Spanner toolset. - return query_tool.execute_sql( - project_id=RAG_SETTINGS.project_id, - instance_id=RAG_SETTINGS.instance_id, - database_id=RAG_SETTINGS.database_id, - query=knn_query, - credentials=credentials, - settings=settings, - tool_context=tool_context, - ) - - -def inspect_tool_params( - tool: BaseTool, - args: Dict[str, Any], - tool_context: ToolContext, -) -> Optional[Dict]: - """A callback function to inspect tool parameters before execution.""" - print("Inspect for tool: " + tool.name) - - actual_search_query_in_args = args.get("search_query") - # Inspect the `search_query` when calling the tool for tutorial purposes. - print(f"Tool args `search_query`: {actual_search_query_in_args}") +# Define Spanner tool config with the vector store settings. +tool_settings = SpannerToolSettings( + capabilities=[Capabilities.DATA_READ], + vector_store_settings=vector_store_settings, +) - pass +# Get the Spanner toolset with the Spanner tool settings and credentials config. +# Filter the tools to only include the `vector_store_similarity_search` tool. +spanner_toolset = SpannerToolset( + credentials_config=credentials_config, + spanner_tool_settings=tool_settings, + # Comment to include all allowed tools. + tool_filter=["vector_store_similarity_search"], +) -### Section 2: Create the root agent ### root_agent = LlmAgent( model="gemini-2.5-flash", name="spanner_knowledge_base_agent", @@ -241,27 +104,10 @@ def inspect_tool_params( ), instruction=""" You are a helpful assistant that answers user questions about product-specific recommendations. - 1. Always use the `wrapped_spanner_similarity_search` tool to find relevant information. - 2. If no relevant information is found, say you don't know. - 3. Present all the relevant information naturally and well formatted in your response. + 1. Always use the `vector_store_similarity_search` tool to find information. + 2. Directly present all the information results from the `vector_store_similarity_search` tool naturally and well formatted in your response. + 3. If no information result is returned by the `vector_store_similarity_search` tool, say you don't know. """, - tools=[ - # # (Option 1) - # # Add customized Spanner tool based on the built-in similarity_search - # # in the Spanner toolset. - GoogleTool( - func=wrapped_spanner_similarity_search, - credentials_config=credentials_config, - tool_settings=tool_settings, - ), - # # (Option 2) - # # Add customized Spanner tool based on the built-in execute_sql in - # # the Spanner toolset. - # GoogleTool( - # func=wrapped_spanner_execute_sql_tool, - # credentials_config=credentials_config, - # tool_settings=tool_settings, - # ), - ], - before_tool_callback=inspect_tool_params, + # Use the Spanner toolset for vector similarity search. 
+ tools=[spanner_toolset], ) diff --git a/src/google/adk/tools/spanner/search_tool.py b/src/google/adk/tools/spanner/search_tool.py index b3cf797edf..2b6b9777dc 100644 --- a/src/google/adk/tools/spanner/search_tool.py +++ b/src/google/adk/tools/spanner/search_tool.py @@ -20,23 +20,34 @@ from typing import List from typing import Optional -from google.adk.tools.spanner import client -from google.adk.tools.spanner.settings import SpannerToolSettings -from google.adk.tools.tool_context import ToolContext from google.auth.credentials import Credentials from google.cloud.spanner_admin_database_v1.types import DatabaseDialect from google.cloud.spanner_v1.database import Database -# Embedding options -_SPANNER_EMBEDDING_MODEL_NAME = "spanner_embedding_model_name" -_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT = "vertex_ai_embedding_model_endpoint" +from . import client +from . import utils +from .settings import APPROXIMATE_NEAREST_NEIGHBORS +from .settings import EXACT_NEAREST_NEIGHBORS +from .settings import SpannerToolSettings + +# Embedding model settings. +# Only for Spanner GoogleSQL dialect database, and use Spanner ML.PREDICT +# function. +_SPANNER_GSQL_EMBEDDING_MODEL_NAME = "spanner_googlesql_embedding_model_name" +# Only for Spanner PostgreSQL dialect database, and use spanner.ML_PREDICT_ROW +# to inferencing with Vertex AI embedding model endpoint. +_SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT = ( + "spanner_postgresql_vertex_ai_embedding_model_endpoint" +) +# For both Spanner GoogleSQL and PostgreSQL dialects, use Vertex AI embedding +# model to generate embeddings for vector similarity search. +_VERTEX_AI_EMBEDDING_MODEL_NAME = "vertex_ai_embedding_model_name" +_OUTPUT_DIMENSIONALITY = "output_dimensionality" # Search options _TOP_K = "top_k" _DISTANCE_TYPE = "distance_type" _NEAREST_NEIGHBORS_ALGORITHM = "nearest_neighbors_algorithm" -_EXACT_NEAREST_NEIGHBORS = "EXACT_NEAREST_NEIGHBORS" -_APPROXIMATE_NEAREST_NEIGHBORS = "APPROXIMATE_NEAREST_NEIGHBORS" _NUM_LEAVES_TO_SEARCH = "num_leaves_to_search" # Constants @@ -48,12 +59,12 @@ def _generate_googlesql_for_embedding_query( - spanner_embedding_model_name: str, + spanner_gsql_embedding_model_name: str, ) -> str: return f""" SELECT embeddings.values FROM ML.PREDICT( - MODEL {spanner_embedding_model_name}, + MODEL {spanner_gsql_embedding_model_name}, (SELECT CAST(@{_GOOGLESQL_PARAMETER_TEXT_QUERY} AS STRING) as content) ) """ @@ -61,37 +72,60 @@ def _generate_googlesql_for_embedding_query( def _generate_postgresql_for_embedding_query( vertex_ai_embedding_model_endpoint: str, + output_dimensionality: Optional[int], ) -> str: + instances_json = f""" + 'instances', + JSONB_BUILD_ARRAY( + JSONB_BUILD_OBJECT( + 'content', + ${_POSTGRESQL_PARAMETER_TEXT_QUERY}::TEXT + ) + ) + """ + + params_list = [] + if output_dimensionality is not None: + params_list.append(f""" + 'parameters', + JSONB_BUILD_OBJECT( + 'outputDimensionality', + {output_dimensionality} + ) + """) + + jsonb_build_args = ",\n".join([instances_json] + params_list) + return f""" - SELECT spanner.FLOAT32_ARRAY( spanner.ML_PREDICT_ROW( - '{vertex_ai_embedding_model_endpoint}', - JSONB_BUILD_OBJECT( - 'instances', - JSONB_BUILD_ARRAY( JSONB_BUILD_OBJECT( - 'content', - ${_POSTGRESQL_PARAMETER_TEXT_QUERY}::TEXT - )) + SELECT spanner.FLOAT32_ARRAY( + spanner.ML_PREDICT_ROW( + '{vertex_ai_embedding_model_endpoint}', + JSONB_BUILD_OBJECT( + {jsonb_build_args} + ) + ) -> 'predictions' -> 0 -> 'embeddings' -> 'values' ) - ) -> 'predictions'->0->'embeddings'->'values' ) """ def 
_get_embedding_for_query( database: Database, dialect: DatabaseDialect, - spanner_embedding_model_name: Optional[str], - vertex_ai_embedding_model_endpoint: Optional[str], + spanner_gsql_embedding_model_name: Optional[str], + spanner_pg_vertex_ai_embedding_model_endpoint: Optional[str], query: str, + output_dimensionality: Optional[int] = None, ) -> List[float]: """Gets the embedding for the query.""" if dialect == DatabaseDialect.POSTGRESQL: embedding_query = _generate_postgresql_for_embedding_query( - vertex_ai_embedding_model_endpoint + spanner_pg_vertex_ai_embedding_model_endpoint, + output_dimensionality, ) params = {f"p{_POSTGRESQL_PARAMETER_TEXT_QUERY}": query} else: embedding_query = _generate_googlesql_for_embedding_query( - spanner_embedding_model_name + spanner_gsql_embedding_model_name ) params = {_GOOGLESQL_PARAMETER_TEXT_QUERY: query} with database.snapshot() as snapshot: @@ -101,8 +135,8 @@ def _get_embedding_for_query( def _get_postgresql_distance_function(distance_type: str) -> str: return { - "COSINE_DISTANCE": "spanner.cosine_distance", - "EUCLIDEAN_DISTANCE": "spanner.euclidean_distance", + "COSINE": "spanner.cosine_distance", + "EUCLIDEAN": "spanner.euclidean_distance", "DOT_PRODUCT": "spanner.dot_product", }[distance_type] @@ -110,13 +144,13 @@ def _get_postgresql_distance_function(distance_type: str) -> str: def _get_googlesql_distance_function(distance_type: str, ann: bool) -> str: if ann: return { - "COSINE_DISTANCE": "APPROX_COSINE_DISTANCE", - "EUCLIDEAN_DISTANCE": "APPROX_EUCLIDEAN_DISTANCE", + "COSINE": "APPROX_COSINE_DISTANCE", + "EUCLIDEAN": "APPROX_EUCLIDEAN_DISTANCE", "DOT_PRODUCT": "APPROX_DOT_PRODUCT", }[distance_type] return { - "COSINE_DISTANCE": "COSINE_DISTANCE", - "EUCLIDEAN_DISTANCE": "EUCLIDEAN_DISTANCE", + "COSINE": "COSINE_DISTANCE", + "EUCLIDEAN": "EUCLIDEAN_DISTANCE", "DOT_PRODUCT": "DOT_PRODUCT", }[distance_type] @@ -172,7 +206,7 @@ def _generate_sql_for_ann( """Generates a SQL query for ANN search.""" if dialect == DatabaseDialect.POSTGRESQL: raise NotImplementedError( - f"{_APPROXIMATE_NEAREST_NEIGHBORS} is not supported for PostgreSQL" + f"{APPROXIMATE_NEAREST_NEIGHBORS} is not supported for PostgreSQL" " dialect." ) distance_function = _get_googlesql_distance_function(distance_type, ann=True) @@ -206,8 +240,6 @@ def similarity_search( columns: List[str], embedding_options: Dict[str, str], credentials: Credentials, - settings: SpannerToolSettings, - tool_context: ToolContext, additional_filter: Optional[str] = None, search_options: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: @@ -234,21 +266,34 @@ def similarity_search( columns (List[str]): A list of column names, representing the additional columns to return in the search results. embedding_options (Dict[str, str]): A dictionary of options to use for - the embedding service. The following options are supported: - - spanner_embedding_model_name: (For GoogleSQL dialect) The + the embedding service. **Exactly one of the following three keys + MUST be present in this dictionary**: + `vertex_ai_embedding_model_name`, `spanner_googlesql_embedding_model_name`, + or `spanner_postgresql_vertex_ai_embedding_model_endpoint`. + - vertex_ai_embedding_model_name (str): (Supported both **GoogleSQL and + PostgreSQL** dialects Spanner database) The name of a + public Vertex AI embedding model (e.g., `'text-embedding-005'`). + If specified, the tool generates embeddings client-side using the + Vertex AI embedding model. 
+ - spanner_googlesql_embedding_model_name (str): (For GoogleSQL dialect) The name of the embedding model that is registered in Spanner via a `CREATE MODEL` statement. For more details, see https://cloud.google.com/spanner/docs/ml-tutorial-embeddings#generate_and_store_text_embeddings - - vertex_ai_embedding_model_endpoint: (For PostgreSQL dialect) - The fully qualified endpoint of the Vertex AI embedding model, - in the format of + If specified, embedding generation is performed using Spanner's + `ML.PREDICT` function. + - spanner_postgresql_vertex_ai_embedding_model_endpoint (str): + (For PostgreSQL dialect) The fully qualified endpoint of the Vertex AI + embedding model, in the format of `projects/$project/locations/$location/publishers/google/models/$model_name`, where $project is the project hosting the Vertex AI endpoint, $location is the location of the endpoint, and $model_name is the name of the text embedding model. + If specified, embedding generation is performed using Spanner's + `spanner.ML_PREDICT_ROW` function. + - output_dimensionality: Optional. The output dimensionality of the + embedding. If not specified, the embedding model's default output + dimensionality will be used. credentials (Credentials): The credentials to use for the request. - settings (SpannerToolSettings): The configuration for the tool. - tool_context (ToolContext): The context for the tool. additional_filter (Optional[str]): An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query. @@ -257,9 +302,9 @@ def similarity_search( - top_k: The number of most similar documents to return. The default value is 4. - distance_type: The distance type to use to perform the - similarity search. Valid values include "COSINE_DISTANCE", - "EUCLIDEAN_DISTANCE", and "DOT_PRODUCT". Default value is - "COSINE_DISTANCE". + similarity search. Valid values include "COSINE", + "EUCLIDEAN", and "DOT_PRODUCT". Default value is + "COSINE". - nearest_neighbors_algorithm: The nearest neighbors search algorithm to use. Valid values include "EXACT_NEAREST_NEIGHBORS" and "APPROXIMATE_NEAREST_NEIGHBORS". Default value is @@ -287,15 +332,13 @@ def similarity_search( ... embedding_column_to_search="product_description_embedding", ... columns=["product_name", "product_description", "price_in_cents"], ... credentials=credentials, - ... settings=settings, - ... tool_context=tool_context, ... additional_filter="price_in_cents < 100000", ... embedding_options={ - ... "spanner_embedding_model_name": "my_embedding_model" + ... "vertex_ai_embedding_model_name": "text-embedding-005" ... }, ... search_options={ ... "top_k": 2, - ... "distance_type": "COSINE_DISTANCE" + ... "distance_type": "COSINE" ... } ... 
) { @@ -336,33 +379,68 @@ def similarity_search( embedding_options = {} if search_options is None: search_options = {} - spanner_embedding_model_name = embedding_options.get( - _SPANNER_EMBEDDING_MODEL_NAME + + exclusive_embedding_model_keys = { + _VERTEX_AI_EMBEDDING_MODEL_NAME, + _SPANNER_GSQL_EMBEDDING_MODEL_NAME, + _SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT, + } + if ( + len( + exclusive_embedding_model_keys.intersection( + embedding_options.keys() + ) + ) + != 1 + ): + raise ValueError("Exactly one embedding model option must be specified.") + + vertex_ai_embedding_model_name = embedding_options.get( + _VERTEX_AI_EMBEDDING_MODEL_NAME + ) + spanner_gsql_embedding_model_name = embedding_options.get( + _SPANNER_GSQL_EMBEDDING_MODEL_NAME + ) + spanner_pg_vertex_ai_embedding_model_endpoint = embedding_options.get( + _SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT ) if ( database.database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL - and spanner_embedding_model_name is None + and vertex_ai_embedding_model_name is None + and spanner_gsql_embedding_model_name is None ): raise ValueError( - f"embedding_options['{_SPANNER_EMBEDDING_MODEL_NAME}']" - " must be specified for GoogleSQL dialect." + f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_NAME}'] or" + f" embedding_options['{_SPANNER_GSQL_EMBEDDING_MODEL_NAME}'] must be" + " specified for GoogleSQL dialect Spanner database." ) - vertex_ai_embedding_model_endpoint = embedding_options.get( - _VERTEX_AI_EMBEDDING_MODEL_ENDPOINT - ) if ( database.database_dialect == DatabaseDialect.POSTGRESQL - and vertex_ai_embedding_model_endpoint is None + and vertex_ai_embedding_model_name is None + and spanner_pg_vertex_ai_embedding_model_endpoint is None ): raise ValueError( - f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT}']" - " must be specified for PostgreSQL dialect." + f"embedding_options['{_VERTEX_AI_EMBEDDING_MODEL_NAME}'] or" + f" embedding_options['{_SPANNER_PG_VERTEX_AI_EMBEDDING_MODEL_ENDPOINT}']" + " must be specified for PostgreSQL dialect Spanner database." + ) + output_dimensionality = embedding_options.get(_OUTPUT_DIMENSIONALITY) + if ( + output_dimensionality is not None + and spanner_gsql_embedding_model_name is not None + ): + # Currently, Spanner GSQL Model ML.PREDICT does not support + # output_dimensionality parameter for inference embedding models. + raise ValueError( + f"embedding_options[{_OUTPUT_DIMENSIONALITY}] is not supported when" + f" embedding_options['{_SPANNER_GSQL_EMBEDDING_MODEL_NAME}'] is" + " specified." ) # Use cosine distance by default. distance_type = search_options.get(_DISTANCE_TYPE) if distance_type is None: - distance_type = "COSINE_DISTANCE" + distance_type = "COSINE" top_k = search_options.get(_TOP_K) if top_k is None: @@ -370,26 +448,36 @@ def similarity_search( # Use EXACT_NEAREST_NEIGHBORS (i.e. kNN) by default. 
nearest_neighbors_algorithm = search_options.get( - _NEAREST_NEIGHBORS_ALGORITHM, _EXACT_NEAREST_NEIGHBORS + _NEAREST_NEIGHBORS_ALGORITHM, + EXACT_NEAREST_NEIGHBORS, ) if nearest_neighbors_algorithm not in ( - _EXACT_NEAREST_NEIGHBORS, - _APPROXIMATE_NEAREST_NEIGHBORS, + EXACT_NEAREST_NEIGHBORS, + APPROXIMATE_NEAREST_NEIGHBORS, ): raise NotImplementedError( f"Unsupported search_options['{_NEAREST_NEIGHBORS_ALGORITHM}']:" f" {nearest_neighbors_algorithm}" ) - embedding = _get_embedding_for_query( - database, - database.database_dialect, - spanner_embedding_model_name, - vertex_ai_embedding_model_endpoint, - query, - ) + # Generate embedding for the query according to the embedding options. + if vertex_ai_embedding_model_name: + embedding = utils.embed_contents( + vertex_ai_embedding_model_name, + [query], + output_dimensionality, + )[0] + else: + embedding = _get_embedding_for_query( + database, + database.database_dialect, + spanner_gsql_embedding_model_name, + spanner_pg_vertex_ai_embedding_model_endpoint, + query, + output_dimensionality, + ) - if nearest_neighbors_algorithm == _EXACT_NEAREST_NEIGHBORS: + if nearest_neighbors_algorithm == EXACT_NEAREST_NEIGHBORS: sql = _generate_sql_for_knn( database.database_dialect, table_name, @@ -438,5 +526,100 @@ def similarity_search( except Exception as ex: return { "status": "ERROR", - "error_details": str(ex), + "error_details": repr(ex), + } + + +def vector_store_similarity_search( + query: str, + credentials: Credentials, + settings: SpannerToolSettings, +) -> Dict[str, Any]: + """Performs a semantic similarity search to retrieve relevant context from the Spanner vector store. + + This function performs vector similarity search directly on a vector store + table in Spanner database and returns the relevant data. + + Args: + query (str): The search string based on the user's question. + credentials (Credentials): The credentials to use for the request. + settings (SpannerToolSettings): The configuration for the tool. + + Returns: + Dict[str, Any]: A dictionary representing the result of the search. + On success, it contains {"status": "SUCCESS", "rows": [...]}. The last + column of each row is the distance between the query and the row result. + On error, it contains {"status": "ERROR", "error_details": "..."}. + + Examples: + >>> vector_store_similarity_search( + ... query="Spanner database optimization techniques for high QPS", + ... credentials=credentials, + ... settings=settings + ... ) + { + "status": "SUCCESS", + "rows": [ + ( + "Optimizing Query Performance", + 0.12, + ), + ( + "Schema Design Best Practices", + 0.25, + ), + ( + "Using Secondary Indexes Effectively", + 0.31, + ), + ... + ], + } + """ + + try: + if not settings or not settings.vector_store_settings: + raise ValueError("Spanner vector store settings are not set.") + + # Get the embedding model settings. + embedding_options = { + _VERTEX_AI_EMBEDDING_MODEL_NAME: ( + settings.vector_store_settings.vertex_ai_embedding_model_name + ), + _OUTPUT_DIMENSIONALITY: settings.vector_store_settings.vector_length, + } + + # Get the search settings. 
+ search_options = { + _TOP_K: settings.vector_store_settings.top_k, + _DISTANCE_TYPE: settings.vector_store_settings.distance_type, + _NEAREST_NEIGHBORS_ALGORITHM: ( + settings.vector_store_settings.nearest_neighbors_algorithm + ), + } + if ( + settings.vector_store_settings.nearest_neighbors_algorithm + == APPROXIMATE_NEAREST_NEIGHBORS + ): + search_options[_NUM_LEAVES_TO_SEARCH] = ( + settings.vector_store_settings.num_leaves_to_search + ) + + return similarity_search( + project_id=settings.vector_store_settings.project_id, + instance_id=settings.vector_store_settings.instance_id, + database_id=settings.vector_store_settings.database_id, + table_name=settings.vector_store_settings.table_name, + query=query, + embedding_column_to_search=settings.vector_store_settings.embedding_column, + columns=settings.vector_store_settings.selected_columns, + embedding_options=embedding_options, + credentials=credentials, + additional_filter=settings.vector_store_settings.additional_filter, + search_options=search_options, + ) + except Exception as ex: + return { + "status": "ERROR", + "error_details": repr(ex), } diff --git a/src/google/adk/tools/spanner/settings.py b/src/google/adk/tools/spanner/settings.py index 5d097258f4..a76331ba65 100644 --- a/src/google/adk/tools/spanner/settings.py +++ b/src/google/adk/tools/spanner/settings.py @@ -16,20 +16,115 @@ from enum import Enum from typing import List +from typing import Literal +from typing import Optional from pydantic import BaseModel +from pydantic import model_validator from ...utils.feature_decorator import experimental +# Vector similarity search nearest neighbors search algorithms. +EXACT_NEAREST_NEIGHBORS = "EXACT_NEAREST_NEIGHBORS" +APPROXIMATE_NEAREST_NEIGHBORS = "APPROXIMATE_NEAREST_NEIGHBORS" +NearestNeighborsAlgorithm = Literal[ + EXACT_NEAREST_NEIGHBORS, + APPROXIMATE_NEAREST_NEIGHBORS, +] + class Capabilities(Enum): """Capabilities indicating what type of operation tools are allowed to be performed on Spanner.""" - DATA_READ = 'data_read' + DATA_READ = "data_read" """Read only data operations tools are allowed.""" -@experimental('Tool settings defaults may have breaking change in the future.') +class SpannerVectorStoreSettings(BaseModel): + """Settings for Spanner Vector Store. + + This is used for vector similarity search in a Spanner vector store table. + Provide the vector store table and the embedding model settings to use with + the `vector_store_similarity_search` tool. + """ + + project_id: str + """Required. The GCP project id in which the Spanner database resides.""" + + instance_id: str + """Required. The instance id of the Spanner database.""" + + database_id: str + """Required. The database id of the Spanner database.""" + + table_name: str + """Required. The name of the vector store table to use for vector similarity search.""" + + content_column: str + """Required. The name of the content column in the vector store table. By default, this column value is also returned as part of the vector similarity search result.""" + + embedding_column: str + """Required. The name of the embedding column to search in the vector store table.""" + + vector_length: int + """Required. The the dimension of the vectors in the `embedding_column`.""" + + vertex_ai_embedding_model_name: str + """Required. The Vertex AI embedding model name, which is used to generate embeddings for vector store and vector similarity search. + For example, 'text-embedding-005'. 
+ + Note: the output dimensionality of the embedding model should be the same as the value specified in the `vector_length` field. + Otherwise, a runtime error might be raised during a query. + """ + + selected_columns: List[str] = [] + """Required. The vector store table columns to return in the vector similarity search result. + + By default, only the `content_column` value and the distance value are returned. + If sepecified, the list of selected columns and the distance value are returned. + For example, if `selected_columns` is ['col1', 'col2'], then the result will contain the values of 'col1' and 'col2' columns and the distance value. + """ + + nearest_neighbors_algorithm: NearestNeighborsAlgorithm = ( + "EXACT_NEAREST_NEIGHBORS" + ) + """The algorithm used to perform vector similarity search. This value can be EXACT_NEAREST_NEIGHBORS or APPROXIMATE_NEAREST_NEIGHBORS. + + For more details about EXACT_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-k-nearest-neighbors + For more details about APPROXIMATE_NEAREST_NEIGHBORS, see https://docs.cloud.google.com/spanner/docs/find-approximate-nearest-neighbors + """ + + top_k: int = 4 + """Required. The number of neighbors to return for each vector similarity search query. The default value is 4.""" + + distance_type: str = "COSINE" + """Required. The distance metric used to build the vector index or perform vector similarity search. This value can be COSINE, DOT_PRODUCT, or EUCLIDEAN.""" + + num_leaves_to_search: Optional[int] = None + """Optional. This option specifies how many leaf nodes of the index are searched. + + Note: this option is only used when the nearest neighbors search algorithm (`nearest_neighbors_algorithm`) is APPROXIMATE_NEAREST_NEIGHBORS. + For more details, see https://docs.cloud.google.com/spanner/docs/vector-index-best-practices + """ + + additional_filter: Optional[str] = None + """Optional. An optional filter to apply to the search query. If provided, this will be added to the WHERE clause of the final query.""" + + @model_validator(mode="after") + def __post_init__(self): + """Validate the embedding settings.""" + if not self.vector_length or self.vector_length <= 0: + raise ValueError( + "Invalid vector length in the Spanner vector store settings." + ) + + if not self.selected_columns: + self.selected_columns = [self.content_column] + + return self + + +@experimental("Tool settings defaults may have breaking change in the future.") class SpannerToolSettings(BaseModel): """Settings for Spanner tools.""" @@ -44,3 +139,6 @@ class SpannerToolSettings(BaseModel): max_executed_query_result_rows: int = 50 """Maximum number of rows to return from a query result.""" + + vector_store_settings: Optional[SpannerVectorStoreSettings] = None + """Settings for Spanner vector store and vector similarity search.""" diff --git a/src/google/adk/tools/spanner/spanner_toolset.py b/src/google/adk/tools/spanner/spanner_toolset.py index 861314abb3..6496014f74 100644 --- a/src/google/adk/tools/spanner/spanner_toolset.py +++ b/src/google/adk/tools/spanner/spanner_toolset.py @@ -47,6 +47,8 @@ class SpannerToolset(BaseToolset): - spanner_list_named_schemas - spanner_get_table_schema - spanner_execute_sql + - spanner_similarity_search + - spanner_vector_store_similarity_search """ def __init__( @@ -121,6 +123,16 @@ async def get_tools( tool_settings=self._tool_settings, ) ) + if self._tool_settings.vector_store_settings: + # Only add the vector store similarity search tool if the vector store + # settings are specified. 
+ all_tools.append( + GoogleTool( + func=search_tool.vector_store_similarity_search, + credentials_config=self._credentials_config, + tool_settings=self._tool_settings, + ) + ) return [ tool diff --git a/src/google/adk/tools/spanner/utils.py b/src/google/adk/tools/spanner/utils.py index 64c03859e1..f1b710ec85 100644 --- a/src/google/adk/tools/spanner/utils.py +++ b/src/google/adk/tools/spanner/utils.py @@ -105,3 +105,27 @@ def execute_sql( "status": "ERROR", "error_details": str(ex), } + + +def embed_contents( + vertex_ai_embedding_model_name: str, + contents: list[str], + output_dimensionality: Optional[int] = None, +) -> list[list[float]]: + """Embed the given contents into list of vectors using the Vertex AI embedding model endpoint.""" + try: + from google.genai import Client + from google.genai.types import EmbedContentConfig + + client = Client() + config = EmbedContentConfig() + if output_dimensionality: + config.output_dimensionality = output_dimensionality + response = client.models.embed_content( + model=vertex_ai_embedding_model_name, + contents=contents, + config=config, + ) + return [list(e.values) for e in response.embeddings] + except Exception as ex: + raise RuntimeError(f"Failed to embed content: {ex!r}") from ex diff --git a/tests/unittests/tools/spanner/test_search_tool.py b/tests/unittests/tools/spanner/test_search_tool.py index 1b330b01d5..c69aa444ec 100644 --- a/tests/unittests/tools/spanner/test_search_tool.py +++ b/tests/unittests/tools/spanner/test_search_tool.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock from unittest.mock import MagicMock -from unittest.mock import patch +from google.adk.tools.spanner import client from google.adk.tools.spanner import search_tool +from google.adk.tools.spanner import utils from google.cloud.spanner_admin_database_v1.types import DatabaseDialect import pytest @@ -35,29 +37,59 @@ def mock_spanner_ids(): } -@patch("google.adk.tools.spanner.client.get_spanner_client") +@pytest.mark.parametrize( + ("embedding_option_key", "embedding_option_value", "expected_embedding"), + [ + pytest.param( + "spanner_googlesql_embedding_model_name", + "EmbeddingsModel", + [0.1, 0.2, 0.3], + id="spanner_googlesql_embedding_model", + ), + pytest.param( + "vertex_ai_embedding_model_name", + "text-embedding-005", + [0.4, 0.5, 0.6], + id="vertex_ai_embedding_model", + ), + ], +) +@mock.patch.object(utils, "embed_contents") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_knn_success( - mock_get_spanner_client, mock_spanner_ids, mock_credentials + mock_get_spanner_client, + mock_embed_contents, + mock_spanner_ids, + mock_credentials, + embedding_option_key, + embedding_option_value, + expected_embedding, ): """Test similarity_search function with kNN success.""" mock_spanner_client = MagicMock() mock_instance = MagicMock() mock_database = MagicMock() mock_snapshot = MagicMock() - mock_embedding_result = MagicMock() - mock_embedding_result.one.return_value = ([0.1, 0.2, 0.3],) - # First call to execute_sql is for getting the embedding - # Second call is for the kNN search - mock_snapshot.execute_sql.side_effect = [ - mock_embedding_result, - iter([("result1",), ("result2",)]), - ] mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL mock_instance.database.return_value = mock_database mock_spanner_client.instance.return_value = 
mock_instance mock_get_spanner_client.return_value = mock_spanner_client + if embedding_option_key == "vertex_ai_embedding_model_name": + mock_embed_contents.return_value = [expected_embedding] + # execute_sql is called once for the kNN search + mock_snapshot.execute_sql.return_value = iter([("result1",), ("result2",)]) + else: + mock_embedding_result = MagicMock() + mock_embedding_result.one.return_value = (expected_embedding,) + # First call to execute_sql is for getting the embedding, + # second call is for the kNN search + mock_snapshot.execute_sql.side_effect = [ + mock_embedding_result, + iter([("result1",), ("result2",)]), + ] + result = search_tool.similarity_search( project_id=mock_spanner_ids["project_id"], instance_id=mock_spanner_ids["instance_id"], @@ -66,10 +98,8 @@ def test_similarity_search_knn_success( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={"spanner_embedding_model_name": "test_model"}, + embedding_options={embedding_option_key: embedding_option_value}, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "SUCCESS", result assert result["rows"] == [("result1",), ("result2",)] @@ -79,10 +109,14 @@ def test_similarity_search_knn_success( sql = call_args.args[0] assert "COSINE_DISTANCE" in sql assert "@embedding" in sql - assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}} + assert call_args.kwargs == {"params": {"embedding": expected_embedding}} + if embedding_option_key == "vertex_ai_embedding_model_name": + mock_embed_contents.assert_called_once_with( + embedding_option_value, ["test query"], None + ) -@patch("google.adk.tools.spanner.client.get_spanner_client") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_ann_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -113,10 +147,10 @@ def test_similarity_search_ann_success( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={"spanner_embedding_model_name": "test_model"}, + embedding_options={ + "spanner_googlesql_embedding_model_name": "test_model" + }, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), search_options={ "nearest_neighbors_algorithm": "APPROXIMATE_NEAREST_NEIGHBORS" }, @@ -130,7 +164,7 @@ def test_similarity_search_ann_success( assert call_args.kwargs == {"params": {"embedding": [0.1, 0.2, 0.3]}} -@patch("google.adk.tools.spanner.client.get_spanner_client") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -143,17 +177,17 @@ def test_similarity_search_error( table_name=mock_spanner_ids["table_name"], query="test query", embedding_column_to_search="embedding_col", - embedding_options={"spanner_embedding_model_name": "test_model"}, + embedding_options={ + "spanner_googlesql_embedding_model_name": "test_model" + }, columns=["col1"], credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "ERROR" - assert result["error_details"] == "Test Exception" + assert "Test Exception" in result["error_details"] -@patch("google.adk.tools.spanner.client.get_spanner_client") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_postgresql_knn_success( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -182,10 +216,12 @@ def 
test_similarity_search_postgresql_knn_success( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={"vertex_ai_embedding_model_endpoint": "test_endpoint"}, + embedding_options={ + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test_endpoint" + ) + }, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "SUCCESS", result assert result["rows"] == [("pg_result",)] @@ -196,7 +232,7 @@ def test_similarity_search_postgresql_knn_success( assert call_args.kwargs == {"params": {"p1": [0.1, 0.2, 0.3]}} -@patch("google.adk.tools.spanner.client.get_spanner_client") +@mock.patch.object(client, "get_spanner_client") def test_similarity_search_postgresql_ann_unsupported( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): @@ -217,27 +253,28 @@ def test_similarity_search_postgresql_ann_unsupported( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={"vertex_ai_embedding_model_endpoint": "test_endpoint"}, + embedding_options={ + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test_endpoint" + ) + }, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), search_options={ "nearest_neighbors_algorithm": "APPROXIMATE_NEAREST_NEIGHBORS" }, ) assert result["status"] == "ERROR" assert ( - result["error_details"] - == "APPROXIMATE_NEAREST_NEIGHBORS is not supported for PostgreSQL" - " dialect." + "APPROXIMATE_NEAREST_NEIGHBORS is not supported for PostgreSQL dialect." + in result["error_details"] ) -@patch("google.adk.tools.spanner.client.get_spanner_client") -def test_similarity_search_missing_spanner_embedding_model_name_error( +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_gsql_missing_embedding_model_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): - """Test similarity_search with missing spanner_embedding_model_name.""" + """Test similarity_search with missing embedding_options for GoogleSQL dialect.""" mock_spanner_client = MagicMock() mock_instance = MagicMock() mock_database = MagicMock() @@ -254,24 +291,27 @@ def test_similarity_search_missing_spanner_embedding_model_name_error( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={}, + embedding_options={ + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test_endpoint" + ) + }, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "ERROR" assert ( - "embedding_options['spanner_embedding_model_name'] must be" - " specified for GoogleSQL dialect." + "embedding_options['vertex_ai_embedding_model_name'] or" + " embedding_options['spanner_googlesql_embedding_model_name'] must be" + " specified for GoogleSQL dialect Spanner database." 
in result["error_details"] ) -@patch("google.adk.tools.spanner.client.get_spanner_client") -def test_similarity_search_missing_vertex_ai_embedding_model_endpoint_error( +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_pg_missing_embedding_model_error( mock_get_spanner_client, mock_spanner_ids, mock_credentials ): - """Test similarity_search with missing vertex_ai_embedding_model_endpoint.""" + """Test similarity_search with missing embedding_options for PostgreSQL dialect.""" mock_spanner_client = MagicMock() mock_instance = MagicMock() mock_database = MagicMock() @@ -288,14 +328,153 @@ def test_similarity_search_missing_vertex_ai_embedding_model_endpoint_error( query="test query", embedding_column_to_search="embedding_col", columns=["col1"], - embedding_options={}, + embedding_options={ + "spanner_googlesql_embedding_model_name": "EmbeddingsModel" + }, + credentials=mock_credentials, + ) + assert result["status"] == "ERROR" + assert ( + "embedding_options['vertex_ai_embedding_model_name'] or" + " embedding_options['spanner_postgresql_vertex_ai_embedding_model_endpoint']" + " must be specified for PostgreSQL dialect Spanner database." + in result["error_details"] + ) + + +@pytest.mark.parametrize( + "embedding_options", + [ + pytest.param( + { + "vertex_ai_embedding_model_name": "test-model", + "spanner_googlesql_embedding_model_name": "test-model-2", + }, + id="vertex_ai_and_googlesql", + ), + pytest.param( + { + "vertex_ai_embedding_model_name": "test-model", + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test-endpoint" + ), + }, + id="vertex_ai_and_postgresql", + ), + pytest.param( + { + "spanner_googlesql_embedding_model_name": "test-model", + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test-endpoint" + ), + }, + id="googlesql_and_postgresql", + ), + pytest.param( + { + "vertex_ai_embedding_model_name": "test-model", + "spanner_googlesql_embedding_model_name": "test-model-2", + "spanner_postgresql_vertex_ai_embedding_model_endpoint": ( + "test-endpoint" + ), + }, + id="all_three_models", + ), + pytest.param( + {}, + id="no_models", + ), + ], +) +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_multiple_embedding_options_error( + mock_get_spanner_client, + mock_spanner_ids, + mock_credentials, + embedding_options, +): + """Test similarity_search with multiple embedding models.""" + mock_spanner_client = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = search_tool.similarity_search( + project_id=mock_spanner_ids["project_id"], + instance_id=mock_spanner_ids["instance_id"], + database_id=mock_spanner_ids["database_id"], + table_name=mock_spanner_ids["table_name"], + query="test query", + embedding_column_to_search="embedding_col", + columns=["col1"], + embedding_options=embedding_options, credentials=mock_credentials, - settings=MagicMock(), - tool_context=MagicMock(), ) assert result["status"] == "ERROR" assert ( - "embedding_options['vertex_ai_embedding_model_endpoint'] must " - "be specified for PostgreSQL dialect." + "Exactly one embedding model option must be specified." 
in result["error_details"] ) + + +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_output_dimensionality_gsql_error( + mock_get_spanner_client, mock_spanner_ids, mock_credentials +): + """Test similarity_search with output_dimensionality and spanner_googlesql_embedding_model_name.""" + mock_spanner_client = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = search_tool.similarity_search( + project_id=mock_spanner_ids["project_id"], + instance_id=mock_spanner_ids["instance_id"], + database_id=mock_spanner_ids["database_id"], + table_name=mock_spanner_ids["table_name"], + query="test query", + embedding_column_to_search="embedding_col", + columns=["col1"], + embedding_options={ + "spanner_googlesql_embedding_model_name": "EmbeddingsModel", + "output_dimensionality": 128, + }, + credentials=mock_credentials, + ) + assert result["status"] == "ERROR" + assert "is not supported when" in result["error_details"] + + +@mock.patch.object(client, "get_spanner_client") +def test_similarity_search_unsupported_algorithm_error( + mock_get_spanner_client, mock_spanner_ids, mock_credentials +): + """Test similarity_search with an unsupported nearest neighbors algorithm.""" + mock_spanner_client = MagicMock() + mock_instance = MagicMock() + mock_database = MagicMock() + mock_database.database_dialect = DatabaseDialect.GOOGLE_STANDARD_SQL + mock_instance.database.return_value = mock_database + mock_spanner_client.instance.return_value = mock_instance + mock_get_spanner_client.return_value = mock_spanner_client + + result = search_tool.similarity_search( + project_id=mock_spanner_ids["project_id"], + instance_id=mock_spanner_ids["instance_id"], + database_id=mock_spanner_ids["database_id"], + table_name=mock_spanner_ids["table_name"], + query="test query", + embedding_column_to_search="embedding_col", + columns=["col1"], + embedding_options={"vertex_ai_embedding_model_name": "test-model"}, + credentials=mock_credentials, + search_options={"nearest_neighbors_algorithm": "INVALID_ALGORITHM"}, + ) + assert result["status"] == "ERROR" + assert "Unsupported search_options" in result["error_details"] diff --git a/tests/unittests/tools/spanner/test_spanner_tool_settings.py b/tests/unittests/tools/spanner/test_spanner_tool_settings.py index f74922b248..730c9e0efe 100644 --- a/tests/unittests/tools/spanner/test_spanner_tool_settings.py +++ b/tests/unittests/tools/spanner/test_spanner_tool_settings.py @@ -15,9 +15,23 @@ from __future__ import annotations from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings +from pydantic import ValidationError import pytest +def common_spanner_vector_store_settings(vector_length=None): + return { + "project_id": "test-project", + "instance_id": "test-instance", + "database_id": "test-database", + "table_name": "test-table", + "content_column": "test-content-column", + "embedding_column": "test-embedding-column", + "vector_length": 128 if vector_length is None else vector_length, + } + + def test_spanner_tool_settings_experimental_warning(): """Test SpannerToolSettings experimental warning.""" with pytest.warns( @@ -25,3 +39,34 @@ def test_spanner_tool_settings_experimental_warning(): match="Tool 
settings defaults may have breaking change in the future.", ): SpannerToolSettings() + + +def test_spanner_vector_store_settings_all_fields_present(): + """Test SpannerVectorStoreSettings with all required fields present.""" + settings = SpannerVectorStoreSettings( + **common_spanner_vector_store_settings(), + vertex_ai_embedding_model_name="test-embedding-model", + ) + assert settings is not None + assert settings.selected_columns == ["test-content-column"] + assert settings.vertex_ai_embedding_model_name == "test-embedding-model" + + +def test_spanner_vector_store_settings_missing_embedding_model_name(): + """Test SpannerVectorStoreSettings with missing vertex_ai_embedding_model_name.""" + with pytest.raises(ValidationError) as excinfo: + SpannerVectorStoreSettings(**common_spanner_vector_store_settings()) + assert "Field required" in str(excinfo.value) + assert "vertex_ai_embedding_model_name" in str(excinfo.value) + + +def test_spanner_vector_store_settings_invalid_vector_length(): + """Test SpannerVectorStoreSettings with invalid vector_length.""" + with pytest.raises(ValidationError) as excinfo: + SpannerVectorStoreSettings( + **common_spanner_vector_store_settings(vector_length=0), + vertex_ai_embedding_model_name="test-embedding-model", + ) + assert "Invalid vector length in the Spanner vector store settings." in str( + excinfo.value + ) diff --git a/tests/unittests/tools/spanner/test_spanner_toolset.py b/tests/unittests/tools/spanner/test_spanner_toolset.py index 163832559d..a583a2f884 100644 --- a/tests/unittests/tools/spanner/test_spanner_toolset.py +++ b/tests/unittests/tools/spanner/test_spanner_toolset.py @@ -18,6 +18,7 @@ from google.adk.tools.spanner import SpannerCredentialsConfig from google.adk.tools.spanner import SpannerToolset from google.adk.tools.spanner.settings import SpannerToolSettings +from google.adk.tools.spanner.settings import SpannerVectorStoreSettings import pytest @@ -184,3 +185,50 @@ async def test_spanner_toolset_without_read_capability( expected_tool_names = set(returned_tools) actual_tool_names = set([tool.name for tool in tools]) assert actual_tool_names == expected_tool_names + + +@pytest.mark.asyncio +async def test_spanner_toolset_with_vector_store_search(): + """Test Spanner toolset with vector store search. + + This test verifies the behavior of the Spanner toolset when vector store + settings is provided. 
+ """ + credentials_config = SpannerCredentialsConfig( + client_id="abc", client_secret="def" + ) + + spanner_tool_settings = SpannerToolSettings( + vector_store_settings=SpannerVectorStoreSettings( + project_id="test-project", + instance_id="test-instance", + database_id="test-database", + table_name="test-table", + content_column="test-content-column", + embedding_column="test-embedding-column", + vector_length=128, + vertex_ai_embedding_model_name="test-embedding-model", + ) + ) + toolset = SpannerToolset( + credentials_config=credentials_config, + spanner_tool_settings=spanner_tool_settings, + ) + tools = await toolset.get_tools() + assert tools is not None + + assert len(tools) == 8 + assert all([isinstance(tool, GoogleTool) for tool in tools]) + + expected_tool_names = set([ + "list_table_names", + "list_table_indexes", + "list_table_index_columns", + "list_named_schemas", + "get_table_schema", + "execute_sql", + "similarity_search", + "vector_store_similarity_search", + ]) + actual_tool_names = set([tool.name for tool in tools]) + assert actual_tool_names == expected_tool_names From 77401132d1b0fb4d2db0ffab00590c1e3961156a Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Tue, 2 Dec 2025 12:15:15 -0800 Subject: [PATCH 54/63] chore: Add migration guide for DatabaseSessionService Co-authored-by: Shangjie Chen PiperOrigin-RevId: 839376632 --- src/google/adk/sessions/migration/README.md | 109 ++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 src/google/adk/sessions/migration/README.md diff --git a/src/google/adk/sessions/migration/README.md b/src/google/adk/sessions/migration/README.md new file mode 100644 index 0000000000..6a9079534e --- /dev/null +++ b/src/google/adk/sessions/migration/README.md @@ -0,0 +1,109 @@ +# Process for a New Schema Version + +This document outlines the steps required to introduce a new database schema +version for `DatabaseSessionService`. Let's assume you are introducing schema +version `2.0`, migrating from `1.0`. + +## 1. Update SQLAlchemy Models + +Modify the SQLAlchemy model classes (`StorageSession`, `StorageEvent`, +`StorageAppState`, `StorageUserState`, `StorageMetadata`) in +`database_session_service.py` to reflect the new `2.0` schema. This could +involve adding new `mapped_column` definitions, changing types, or adding new +classes for new tables. + +## 2. Create a New Migration Script + +You need to create a script that migrates data from schema `1.0` to `2.0`. + +* Create a new file, for example: + `google/adk/sessions/migration/migrate_1_0_to_2_0.py`. +* This script must contain a `migrate(source_db_url: str, dest_db_url: str)` + function, similar to `migrate_from_sqlalchemy_pickle.py`. +* Inside this function: + * Connect to the `source_db_url` (which has schema 1.0) and `dest_db_url` + engines using SQLAlchemy. + * **Important**: Create the tables in the destination database using the + new 2.0 schema definition by calling + `dss.Base.metadata.create_all(dest_engine)`. + * Read data from the source tables (schema 1.0). The recommended way to do + this without relying on outdated models is to use `sqlalchemy.text`, + like: + + ```python + from sqlalchemy import text + ... + rows = source_session.execute(text("SELECT * FROM sessions")).mappings().all() + ``` + + * For each row read from the source, transform the data as necessary to + fit the `2.0` schema, and create an instance of the corresponding new + SQLAlchemy model (e.g., `dss.StorageSession(...)`). 
+ * Add these new `2.0` objects to the destination session, ideally using + `dest_session.merge()` to upsert. + * After migrating data for all tables, ensure the destination database is + marked with the new schema version: + + ```python + from google.adk.sessions import database_session_service as dss + from google.adk.sessions.migration import _schema_check + ... + dest_session.merge( + dss.StorageMetadata( + key=_schema_check.SCHEMA_VERSION_KEY, + value="2.0", + ) + ) + dest_session.commit() + ``` + +## 3. Update Schema Version Constant + +You need to update `CURRENT_SCHEMA_VERSION` in +`google/adk/sessions/migration/_schema_check.py` to reflect the new version: + +```python +CURRENT_SCHEMA_VERSION = "2.0" +``` + +This will also update `LATEST_VERSION` in `migration_runner.py`, as it uses this +constant. + +## 4. Register the New Migration in Migration Runner + +In `google/adk/sessions/migration/migration_runner.py`, import your new +migration script and add it to the `MIGRATIONS` dictionary. This tells the +runner how to get from version `1.0` to `2.0`. For example: + +```python +from google.adk.sessions.migration import _schema_check +from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle +from google.adk.sessions.migration import migrate_1_0_to_2_0 +... +MIGRATIONS = { + _schema_check.SCHEMA_VERSION_0_1_PICKLE: ( + _schema_check.SCHEMA_VERSION_1_0_JSON, + migrate_from_sqlalchemy_pickle.migrate, + ), + _schema_check.SCHEMA_VERSION_1_0_JSON: ( + "2.0", + migrate_1_0_to_2_0.migrate, + ), +} +``` + +## 5. Update `DatabaseSessionService` Business Logic + +If your schema change affects how data should be read or written during normal +operation (e.g., you added a new column that needs to be populated on session +creation), update the methods within `DatabaseSessionService` (`create_session`, +`get_session`, `append_event`, etc.) in `database_session_service.py` +accordingly. + +## 6. CLI Command Changes + +No changes are needed for the Click command definition in `cli_tools_click.py`. +The `adk migrate session` command calls `migration_runner.upgrade()`, which will +now automatically detect the source database version and apply the necessary +migration steps (e.g., `0.1 -> 1.0 -> 2.0`, or `1.0 -> 2.0`) to reach +`LATEST_VERSION`. 
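
For reference, the steps from section 2 combine into a migration script skeleton like the one below. This is a minimal sketch, not the actual migration: the 1.0-to-2.0 field mapping is illustrative and must be adapted to the real schema change, and only the `sessions` table loop is shown.

```python
from google.adk.sessions import database_session_service as dss
from google.adk.sessions.migration import _schema_check
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy.orm import Session


def migrate(source_db_url: str, dest_db_url: str) -> None:
  """Migrates session data from schema 1.0 to 2.0 (illustrative sketch)."""
  source_engine = create_engine(source_db_url)
  dest_engine = create_engine(dest_db_url)

  # Create destination tables from the new 2.0 SQLAlchemy models.
  dss.Base.metadata.create_all(dest_engine)

  with Session(source_engine) as source_session, Session(
      dest_engine
  ) as dest_session:
    # Read raw 1.0 rows without depending on outdated models.
    rows = (
        source_session.execute(text("SELECT * FROM sessions")).mappings().all()
    )
    for row in rows:
      # Transform each 1.0 row into the 2.0 model; the pass-through mapping
      # here is a placeholder for your real column transformations.
      dest_session.merge(dss.StorageSession(**dict(row)))

    # Repeat the read/transform/merge loop for the remaining tables
    # (events, app states, user states) as needed.

    # Mark the destination database with the new schema version.
    dest_session.merge(
        dss.StorageMetadata(
            key=_schema_check.SCHEMA_VERSION_KEY,
            value="2.0",
        )
    )
    dest_session.commit()
```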
\ No newline at end of file From 3aef9a18b1f8b54cc1712af7b3477049e4f52932 Mon Sep 17 00:00:00 2001 From: Xuan Yang Date: Tue, 2 Dec 2025 12:54:11 -0800 Subject: [PATCH 55/63] docs: Update ADK issue triaging agent to add component label before planned Co-authored-by: Xuan Yang PiperOrigin-RevId: 839391092 --- .github/workflows/triage.yml | 18 ++- .../samples/adk_triaging_agent/README.md | 44 +++++-- .../samples/adk_triaging_agent/agent.py | 117 +++++++++++++----- .../samples/adk_triaging_agent/main.py | 58 +++++---- 4 files changed, 174 insertions(+), 63 deletions(-) diff --git a/.github/workflows/triage.yml b/.github/workflows/triage.yml index 97ddc0efa7..46153f413a 100644 --- a/.github/workflows/triage.yml +++ b/.github/workflows/triage.yml @@ -2,16 +2,22 @@ name: ADK Issue Triaging Agent on: issues: - types: [labeled] + types: [opened, labeled] schedule: - # Run every 6 hours to triage planned but not triaged issues + # Run every 6 hours to triage untriaged issues - cron: '0 */6 * * *' jobs: agent-triage-issues: runs-on: ubuntu-latest - # Only run if labeled with "planned" or if it's a scheduled run - if: github.event_name == 'schedule' || github.event.label.name == 'planned' + # Run for: + # - Scheduled runs (batch processing) + # - New issues (need component labeling) + # - Issues labeled with "planned" (need owner assignment) + if: >- + github.event_name == 'schedule' || + github.event.action == 'opened' || + github.event.label.name == 'planned' permissions: issues: write contents: read @@ -35,8 +41,8 @@ jobs: GITHUB_TOKEN: ${{ secrets.ADK_TRIAGE_AGENT }} GOOGLE_API_KEY: ${{ secrets.GOOGLE_API_KEY }} GOOGLE_GENAI_USE_VERTEXAI: 0 - OWNER: 'google' - REPO: 'adk-python' + OWNER: ${{ github.repository_owner }} + REPO: ${{ github.event.repository.name }} INTERACTIVE: 0 EVENT_NAME: ${{ github.event_name }} # 'issues', 'schedule', etc. ISSUE_NUMBER: ${{ github.event.issue.number }} diff --git a/contributing/samples/adk_triaging_agent/README.md b/contributing/samples/adk_triaging_agent/README.md index be4071b61b..0420dba718 100644 --- a/contributing/samples/adk_triaging_agent/README.md +++ b/contributing/samples/adk_triaging_agent/README.md @@ -1,18 +1,39 @@ # ADK Issue Triaging Assistant -The ADK Issue Triaging Assistant is a Python-based agent designed to help manage and triage GitHub issues for the `google/adk-python` repository. It uses a large language model to analyze new and unlabelled issues, recommend appropriate labels based on a predefined set of rules, and apply them. +The ADK Issue Triaging Assistant is a Python-based agent designed to help manage and triage GitHub issues for the `google/adk-python` repository. It uses a large language model to analyze issues, recommend appropriate component labels, set issue types, and assign owners based on predefined rules. This agent can be operated in two distinct modes: an interactive mode for local use or as a fully automated GitHub Actions workflow. 
--- +## Triaging Workflow + +The agent performs different actions based on the issue state: + +| Condition | Actions | +|-----------|---------| +| Issue without component label | Add component label + Set issue type (Bug/Feature) | +| Issue with "planned" label but no assignee | Assign owner based on component label | +| Issue with "planned" label AND no component label | Add component label + Set type + Assign owner | + +### Component Labels +The agent can assign the following component labels, each mapped to an owner: +- `core`, `tools`, `mcp`, `eval`, `live`, `models`, `tracing`, `web`, `services`, `documentation`, `question`, `agent engine`, `a2a`, `bq` + +### Issue Types +Based on the issue content, the agent will set the issue type to: +- **Bug**: For bug reports +- **Feature**: For feature requests + +--- + ## Interactive Mode This mode allows you to run the agent locally to review its recommendations in real-time before any changes are made to your repository's issues. ### Features * **Web Interface**: The agent's interactive mode can be rendered in a web browser using the ADK's `adk web` command. -* **User Approval**: In interactive mode, the agent is instructed to ask for your confirmation before applying a label to a GitHub issue. +* **User Approval**: In interactive mode, the agent is instructed to ask for your confirmation before applying labels or assigning owners. ### Running in Interactive Mode To run the agent in interactive mode, first set the required environment variables. Then, execute the following command in your terminal: @@ -31,12 +52,19 @@ For automated, hands-off issue triaging, the agent can be integrated directly in ### Workflow Triggers The GitHub workflow is configured to run on specific triggers: -1. **Issue Events**: The workflow executes automatically whenever a new issue is `opened` or an existing one is `reopened`. +1. **New Issues (`opened`)**: When a new issue is created, the agent adds an appropriate component label and sets the issue type. + +2. **Planned Label Added (`labeled` with "planned")**: When an issue is labeled as "planned", the agent assigns an owner based on the component label. If the issue doesn't have a component label yet, the agent will also add one. + +3. **Scheduled Runs**: The workflow runs every 6 hours to process any issues that need triaging (either missing component labels or missing assignees for "planned" issues). -2. **Scheduled Runs**: The workflow also runs on a recurring schedule (every 6 hours) to process any unlabelled issues that may have been missed. +### Automated Actions +When running as part of the GitHub workflow, the agent operates non-interactively: +- **Component Labeling**: Automatically applies the most appropriate component label +- **Issue Type Setting**: Sets the issue type to Bug or Feature based on content +- **Owner Assignment**: Only assigns owners for issues marked as "planned" -### Automated Labeling -When running as part of the GitHub workflow, the agent operates non-interactively. It identifies the best label and applies it directly without requiring user approval. This behavior is configured by setting the `INTERACTIVE` environment variable to `0` in the workflow file. +This behavior is configured by setting the `INTERACTIVE` environment variable to `0` in the workflow file. ### Workflow Configuration The workflow is defined in a YAML file (`.github/workflows/triage.yml`). 
This file contains the steps to check out the code, set up the Python environment, install dependencies, and run the triaging script with the necessary environment variables and secrets. @@ -60,8 +88,8 @@ The following environment variables are required for the agent to connect to the * `GITHUB_TOKEN`: **(Required)** A GitHub Personal Access Token with `issues:write` permissions. Needed for both interactive and workflow modes. * `GOOGLE_API_KEY`: **(Required)** Your API key for the Gemini API. Needed for both interactive and workflow modes. -* `OWNER`: The GitHub organization or username that owns the repository (e.g., `google`). Needed for both modes. -* `REPO`: The name of the GitHub repository (e.g., `adk-python`). Needed for both modes. +* `OWNER`: The GitHub organization or username that owns the repository (e.g., `google`). In the workflow, this is automatically set from the repository context. +* `REPO`: The name of the GitHub repository (e.g., `adk-python`). In the workflow, this is automatically set from the repository context. * `INTERACTIVE`: Controls the agent's interaction mode. For the automated workflow, this is set to `0`. For interactive mode, it should be set to `1` or left unset. For local execution in interactive mode, you can place these variables in a `.env` file in the project's root directory. For the GitHub workflow, they should be configured as repository secrets. \ No newline at end of file diff --git a/contributing/samples/adk_triaging_agent/agent.py b/contributing/samples/adk_triaging_agent/agent.py index 167eb3a616..d3e653f1d0 100644 --- a/contributing/samples/adk_triaging_agent/agent.py +++ b/contributing/samples/adk_triaging_agent/agent.py @@ -78,22 +78,27 @@ APPROVAL_INSTRUCTION = "Only label them when the user approves the labeling!" -def list_planned_untriaged_issues(issue_count: int) -> dict[str, Any]: - """List planned issues without component labels (e.g., core, tools, etc.). +def list_untriaged_issues(issue_count: int) -> dict[str, Any]: + """List open issues that need triaging. + + Returns issues that need any of the following actions: + 1. Issues without component labels (need labeling + type setting) + 2. Issues with 'planned' label but no assignee (need owner assignment) Args: issue_count: number of issues to return Returns: The status of this request, with a list of issues when successful. + Each issue includes flags indicating what actions are needed. 
""" url = f"{GITHUB_BASE_URL}/search/issues" - query = f"repo:{OWNER}/{REPO} is:open is:issue label:planned" + query = f"repo:{OWNER}/{REPO} is:open is:issue" params = { "q": query, "sort": "created", "order": "desc", - "per_page": issue_count, + "per_page": 100, # Fetch more to filter "page": 1, } @@ -103,29 +108,46 @@ def list_planned_untriaged_issues(issue_count: int) -> dict[str, Any]: return error_response(f"Error: {e}") issues = response.get("items", []) - # Filter out issues that already have component labels component_labels = set(LABEL_TO_OWNER.keys()) untriaged_issues = [] for issue in issues: issue_labels = {label["name"] for label in issue.get("labels", [])} - # If the issue only has "planned" but no component labels, it's untriaged - if not (issue_labels & component_labels): + assignees = issue.get("assignees", []) + + existing_component_labels = issue_labels & component_labels + has_component = bool(existing_component_labels) + has_planned = "planned" in issue_labels + + # Determine what actions are needed + needs_component_label = not has_component + needs_owner = has_planned and not assignees + + # Include issue if it needs any action + if needs_component_label or needs_owner: + issue["has_planned_label"] = has_planned + issue["has_component_label"] = has_component + issue["existing_component_label"] = ( + list(existing_component_labels)[0] + if existing_component_labels + else None + ) + issue["needs_component_label"] = needs_component_label + issue["needs_owner"] = needs_owner untriaged_issues.append(issue) + if len(untriaged_issues) >= issue_count: + break return {"status": "success", "issues": untriaged_issues} -def add_label_and_owner_to_issue( - issue_number: int, label: str -) -> dict[str, Any]: - """Add the specified label and owner to the given issue number. +def add_label_to_issue(issue_number: int, label: str) -> dict[str, Any]: + """Add the specified component label to the given issue number. Args: issue_number: issue number of the GitHub issue. label: label to assign Returns: - The the status of this request, with the applied label and assigned owner - when successful. + The status of this request, with the applied label when successful. """ print(f"Attempting to add label '{label}' to issue #{issue_number}") if label not in LABEL_TO_OWNER: @@ -143,15 +165,38 @@ def add_label_and_owner_to_issue( except requests.exceptions.RequestException as e: return error_response(f"Error: {e}") + return { + "status": "success", + "message": response, + "applied_label": label, + } + + +def add_owner_to_issue(issue_number: int, label: str) -> dict[str, Any]: + """Assign an owner to the issue based on the component label. + + This should only be called for issues that have the 'planned' label. + + Args: + issue_number: issue number of the GitHub issue. + label: component label that determines the owner to assign + + Returns: + The status of this request, with the assigned owner when successful. + """ + print( + f"Attempting to assign owner for label '{label}' to issue #{issue_number}" + ) + if label not in LABEL_TO_OWNER: + return error_response( + f"Error: Label '{label}' is not a valid component label." + ) + owner = LABEL_TO_OWNER.get(label, None) if not owner: return { "status": "warning", - "message": ( - f"{response}\n\nLabel '{label}' does not have an owner. Will not" - " assign." - ), - "applied_label": label, + "message": f"Label '{label}' does not have an owner. 
Will not assign.", } assignee_url = ( @@ -167,7 +212,6 @@ def add_label_and_owner_to_issue( return { "status": "success", "message": response, - "applied_label": label, "assigned_owner": owner, } @@ -223,29 +267,46 @@ def change_issue_type(issue_number: int, issue_type: str) -> dict[str, Any]: - If it's about BigQuery integrations, label it with "bq". - If you can't find an appropriate labels for the issue, follow the previous instruction that starts with "IMPORTANT:". - Call the `add_label_and_owner_to_issue` tool to label the issue, which will also assign the issue to the owner of the label. + ## Triaging Workflow + + Each issue will have flags indicating what actions are needed: + - `needs_component_label`: true if the issue needs a component label + - `needs_owner`: true if the issue needs an owner assigned (has 'planned' label but no assignee) + + For each issue, perform ONLY the required actions based on the flags: + + 1. **If `needs_component_label` is true**: + - Use `add_label_to_issue` to add the appropriate component label + - Use `change_issue_type` to set the issue type: + - Bug report → "Bug" + - Feature request → "Feature" + - Otherwise → do not change the issue type + + 2. **If `needs_owner` is true**: + - Use `add_owner_to_issue` to assign an owner based on the component label + - Note: If the issue already has a component label (`has_component_label: true`), use that existing label to determine the owner - After you label the issue, call the `change_issue_type` tool to change the issue type: - - If the issue is a bug report, change the issue type to "Bug". - - If the issue is a feature request, change the issue type to "Feature". - - Otherwise, **do not change the issue type**. + Do NOT add a component label if `needs_component_label` is false. + Do NOT assign an owner if `needs_owner` is false. Response quality requirements: - Summarize the issue in your own words without leaving template placeholders (never output text like "[fill in later]"). - Justify the chosen label with a short explanation referencing the issue details. - - Mention the assigned owner when a label maps to one. + - Mention the assigned owner only when you actually assign one (i.e., when + the issue has the 'planned' label). - If no label is applied, clearly state why. Present the following in an easy to read format highlighting issue number and your label. 
- the issue summary in a few sentence - your label recommendation and justification - - the owner of the label if you assign the issue to an owner + - the owner of the label if you assign the issue to an owner (only for planned issues) """, tools=[ - list_planned_untriaged_issues, - add_label_and_owner_to_issue, + list_untriaged_issues, + add_label_to_issue, + add_owner_to_issue, change_issue_type, ], ) diff --git a/contributing/samples/adk_triaging_agent/main.py b/contributing/samples/adk_triaging_agent/main.py index f24302ac4b..3a2d4da570 100644 --- a/contributing/samples/adk_triaging_agent/main.py +++ b/contributing/samples/adk_triaging_agent/main.py @@ -46,24 +46,41 @@ async def fetch_specific_issue_details(issue_number: int): issue_data = get_request(url) labels = issue_data.get("labels", []) label_names = {label["name"] for label in labels} + assignees = issue_data.get("assignees", []) - # Check if issue has "planned" label but no component labels + # Check issue state component_labels = set(LABEL_TO_OWNER.keys()) has_planned = "planned" in label_names - has_component = bool(label_names & component_labels) + existing_component_labels = label_names & component_labels + has_component = bool(existing_component_labels) + has_assignee = len(assignees) > 0 - if has_planned and not has_component: - print(f"Issue #{issue_number} is planned but not triaged. Proceeding.") + # Determine what actions are needed + needs_component_label = not has_component + needs_owner = has_planned and not has_assignee + + if needs_component_label or needs_owner: + print( + f"Issue #{issue_number} needs triaging. " + f"needs_component_label={needs_component_label}, " + f"needs_owner={needs_owner}" + ) return { "number": issue_data["number"], "title": issue_data["title"], "body": issue_data.get("body", ""), + "has_planned_label": has_planned, + "has_component_label": has_component, + "existing_component_label": ( + list(existing_component_labels)[0] + if existing_component_labels + else None + ), + "needs_component_label": needs_component_label, + "needs_owner": needs_owner, } else: - print( - f"Issue #{issue_number} is already triaged or doesn't have" - " 'planned' label. Skipping." - ) + print(f"Issue #{issue_number} is already fully triaged. Skipping.") return None except requests.exceptions.RequestException as e: print(f"Error fetching issue #{issue_number}: {e}") @@ -127,25 +144,24 @@ async def main(): issue_title = ISSUE_TITLE or specific_issue["title"] issue_body = ISSUE_BODY or specific_issue["body"] + needs_component_label = specific_issue.get("needs_component_label", True) + needs_owner = specific_issue.get("needs_owner", False) + existing_component_label = specific_issue.get("existing_component_label") + prompt = ( - f"A GitHub issue #{issue_number} has been labeled as 'planned'." - f' Title: "{issue_title}"\nBody:' - f' "{issue_body}"\n\nBased on the rules, recommend an' - " appropriate component label and its justification." - " Then, use the 'add_label_and_owner_to_issue' tool to apply the" - " label directly to this issue. Only label it, do not" - " process any other issues." 
+ f"Triage GitHub issue #{issue_number}.\n\n" + f'Title: "{issue_title}"\n' + f'Body: "{issue_body}"\n\n' + f"Issue state: needs_component_label={needs_component_label}, " + f"needs_owner={needs_owner}, " + f"existing_component_label={existing_component_label}" ) else: print(f"EVENT: Processing batch of issues (event: {EVENT_NAME}).") issue_count = parse_number_string(ISSUE_COUNT_TO_PROCESS, default_value=3) prompt = ( - "Please use the 'list_planned_untriaged_issues' tool to find the" - f" most recent {issue_count} planned issues that haven't been" - " triaged yet (i.e., issues with 'planned' label but no component" - " labels like 'core', 'tools', etc.). Then triage each of them by" - " applying appropriate component labels. If you cannot find any planned" - " issues, please don't try to triage any issues." + f"Please use 'list_untriaged_issues' to find {issue_count} issues that" + " need triaging, then triage each one according to your instructions." ) response = await call_agent_async(runner, USER_ID, session.id, prompt) From b807d62fe35a5de4e1efd37f38230edcf5f992ad Mon Sep 17 00:00:00 2001 From: Faraaz Ahmed Date: Tue, 2 Dec 2025 19:13:38 -0800 Subject: [PATCH 56/63] feat(bigquery): Add labels support to BigQueryToolConfig for job tracking and monitoring Merge https://github.com/google/adk-python/pull/3583 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: #3582 **2. Or, if no issue exists, describe the change:** _If applicable, please follow the issue templates to provide as much detail as possible._ **Problem:** Currently, the BigQuery tool in ADK does not provide a way for developers to add custom labels to BigQuery jobs created by their agents. This makes it difficult to: Track and monitor BigQuery costs associated with specific agents or use cases Organize and filter BigQuery jobs in the Google Cloud Console Implement billing attribution and resource organization strategies Differentiate between jobs from different environments (dev, staging, production) While the tool automatically adds an internal adk-bigquery-tool label with the caller_id, there's no mechanism for users to add their own custom labels for tracking and monitoring purposes. **Solution:** Add a labels configuration field to BigQueryToolConfig that allows users to specify custom key-value pairs to be applied to all BigQuery jobs executed by the agent. The solution should: Configuration Option: Add an optional labels parameter to BigQueryToolConfig accepting a dictionary of string key-value pairs Validation: Ensure labels follow BigQuery's requirements (non-empty string keys, string values) Job Application: Automatically apply configured labels to all BigQuery jobs alongside the existing internal labels Documentation: Provide clear documentation on how to use labels for tracking and monitoring ### Testing Plan _Please describe the tests that you ran to verify your changes. This is required for all PRs that are not small documentation or typo fixes._ **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. 
_Please include a summary of passed `pytest` results._ ``` pytest tests/unittests/tools/bigquery/test_bigquery_tool_config.py -v --tb=line -W ignore::UserWarning ========================================= test session starts ========================================== platform darwin -- Python 3.11.14, pytest-9.0.1, pluggy-1.6.0 -- *****redacted****** cachedir: .pytest_cache rootdir: *****redacted****** configfile: pyproject.toml plugins: mock-3.15.1, anyio-4.11.0, xdist-3.8.0, langsmith-0.4.43, asyncio-1.3.0 asyncio: mode=Mode.AUTO, debug=False, asyncio_default_fixture_loop_scope=function, asyncio_default_test_loop_scope=function collected 14 items tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_experimental_warning PASSED [ 7%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_property PASSED [ 14%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_application_name PASSED [ 21%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_max_query_result_rows_default PASSED [ 28%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_max_query_result_rows_custom PASSED [ 35%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_valid_maximum_bytes_billed PASSED [ 42%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_maximum_bytes_billed PASSED [ 50%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_valid_labels PASSED [ 57%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_empty_labels PASSED [ 64%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_none_labels PASSED [ 71%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_labels_type PASSED [ 78%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_label_key_type PASSED [ 85%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_invalid_label_value_type PASSED [ 92%] tests/unittests/tools/bigquery/test_bigquery_tool_config.py::test_bigquery_tool_config_empty_label_key PASSED [100%] ==================================================================================================== 14 passed in 2.02s ==================================================================================================== ``` **Manual End-to-End (E2E) Tests:** _Please provide instructions on how to manually test your changes, including any necessary setup or configuration. Please provide logs or screenshots to help reviewers better understand the fix._ ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. 
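
For reviewers, a minimal usage sketch of the new `job_labels` option (import path taken from the diff below; the label values are illustrative):

```python
from google.adk.tools.bigquery.config import BigQueryToolConfig

# Custom labels are applied to every BigQuery job the agent runs;
# the internal "adk-bigquery-tool" label is still added on top of these.
config = BigQueryToolConfig(
    job_labels={"environment": "prod", "team": "data"},
)
```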
### Additional context _Add any other context or screenshots about the feature request here._ COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3583 from Faraaz1994:feature/bq_label 0fd7fe6a3b1ee20a36f73562e425d007b8d7dc9d PiperOrigin-RevId: 839523588 --- src/google/adk/tools/bigquery/config.py | 20 ++ src/google/adk/tools/bigquery/query_tool.py | 5 +- .../bigquery/test_bigquery_query_tool.py | 237 ++++++++++++++++++ .../bigquery/test_bigquery_tool_config.py | 58 +++++ 4 files changed, 319 insertions(+), 1 deletion(-) diff --git a/src/google/adk/tools/bigquery/config.py b/src/google/adk/tools/bigquery/config.py index 7768f214ed..39b6a3d9b6 100644 --- a/src/google/adk/tools/bigquery/config.py +++ b/src/google/adk/tools/bigquery/config.py @@ -101,6 +101,16 @@ class BigQueryToolConfig(BaseModel): locations, see https://cloud.google.com/bigquery/docs/locations. """ + job_labels: Optional[dict[str, str]] = None + """Labels to apply to BigQuery jobs for tracking and monitoring. + + These labels will be added to all BigQuery jobs executed by the tools. + Labels must be key-value pairs where both keys and values are strings. + Labels can be used for billing, monitoring, and resource organization. + For more information about labels, see + https://cloud.google.com/bigquery/docs/labels-intro. + """ + @field_validator('maximum_bytes_billed') @classmethod def validate_maximum_bytes_billed(cls, v): @@ -121,3 +131,13 @@ def validate_application_name(cls, v): if v and ' ' in v: raise ValueError('Application name should not contain spaces.') return v + + @field_validator('job_labels') + @classmethod + def validate_job_labels(cls, v): + """Validate that job_labels keys are not empty.""" + if v is not None: + for key in v.keys(): + if not key: + raise ValueError('Label keys cannot be empty.') + return v diff --git a/src/google/adk/tools/bigquery/query_tool.py b/src/google/adk/tools/bigquery/query_tool.py index 666dc3c5a1..5bcd734e70 100644 --- a/src/google/adk/tools/bigquery/query_tool.py +++ b/src/google/adk/tools/bigquery/query_tool.py @@ -68,7 +68,10 @@ def _execute_sql( bq_connection_properties = [] # BigQuery job labels if applicable - bq_job_labels = {} + bq_job_labels = ( + settings.job_labels.copy() if settings and settings.job_labels else {} + ) + if caller_id: bq_job_labels["adk-bigquery-tool"] = caller_id if settings and settings.application_name: diff --git a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py index eef83a1f5e..1791100e1f 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_query_tool.py +++ b/tests/unittests/tools/bigquery/test_bigquery_query_tool.py @@ -1709,6 +1709,65 @@ def test_execute_sql_job_labels( } +@pytest.mark.parametrize( + ("write_mode", "dry_run", "query_call_count", "query_and_wait_call_count"), + [ + pytest.param(WriteMode.ALLOWED, False, 0, 1, id="write-allowed"), + pytest.param(WriteMode.ALLOWED, True, 1, 0, id="write-allowed-dry-run"), + pytest.param(WriteMode.BLOCKED, False, 1, 1, id="write-blocked"), + pytest.param(WriteMode.BLOCKED, True, 2, 0, id="write-blocked-dry-run"), + pytest.param(WriteMode.PROTECTED, False, 2, 1, id="write-protected"), + pytest.param( + WriteMode.PROTECTED, True, 3, 0, id="write-protected-dry-run" + ), + ], +) +def test_execute_sql_user_job_labels_augment_internal_labels( + write_mode, dry_run, query_call_count, query_and_wait_call_count +): + """Test execute_sql tool augments user job_labels with internal labels.""" + project = "my_project" 
+ query = "SELECT 123 AS num" + statement_type = "SELECT" + credentials = mock.create_autospec(Credentials, instance=True) + user_labels = {"environment": "test", "team": "data"} + tool_settings = BigQueryToolConfig( + write_mode=write_mode, + job_labels=user_labels, + ) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = None + + with mock.patch.object(bigquery, "Client", autospec=True) as Client: + bq_client = Client.return_value + + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.statement_type = statement_type + bq_client.query.return_value = query_job + + query_tool.execute_sql( + project, + query, + credentials, + tool_settings, + tool_context, + dry_run=dry_run, + ) + + assert bq_client.query.call_count == query_call_count + assert bq_client.query_and_wait.call_count == query_and_wait_call_count + # Build expected labels from user_labels + internal label + expected_labels = {**user_labels, "adk-bigquery-tool": "execute_sql"} + for call_args_list in [ + bq_client.query.call_args_list, + bq_client.query_and_wait.call_args_list, + ]: + for call_args in call_args_list: + _, mock_kwargs = call_args + # Verify user labels are preserved and internal label is added + assert mock_kwargs["job_config"].labels == expected_labels + + @pytest.mark.parametrize( ("tool_call", "expected_tool_label"), [ @@ -1850,6 +1909,94 @@ def test_ml_tool_job_labels_w_application_name(tool_call, expected_tool_label): assert mock_kwargs["job_config"].labels == expected_labels +@pytest.mark.parametrize( + ("tool_call", "expected_labels"), + [ + pytest.param( + lambda tool_context: query_tool.forecast( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + timestamp_col="ts_col", + data_col="data_col", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "forecaster"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "forecaster", + "adk-bigquery-tool": "forecast", + }, + id="forecast", + ), + pytest.param( + lambda tool_context: query_tool.analyze_contribution( + project_id="test-project", + input_data="test-dataset.test-table", + dimension_id_cols=["dim1", "dim2"], + contribution_metric="SUM(metric)", + is_test_col="is_test", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "analyzer"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "analyzer", + "adk-bigquery-tool": "analyze_contribution", + }, + id="analyze-contribution", + ), + pytest.param( + lambda tool_context: query_tool.detect_anomalies( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + credentials=mock.create_autospec(Credentials, instance=True), + settings=BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels={"environment": "prod", "app": "detector"}, + ), + tool_context=tool_context, + ), + { + "environment": "prod", + "app": "detector", + "adk-bigquery-tool": "detect_anomalies", + }, + id="detect-anomalies", + ), + ], +) +def test_ml_tool_user_job_labels_augment_internal_labels( + tool_call, expected_labels +): + """Test ML tools augment user job_labels with internal labels.""" + + with mock.patch.object(bigquery, "Client", 
autospec=True) as Client: + bq_client = Client.return_value + + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = None + tool_call(tool_context) + + for call_args_list in [ + bq_client.query.call_args_list, + bq_client.query_and_wait.call_args_list, + ]: + for call_args in call_args_list: + _, mock_kwargs = call_args + # Verify user labels are preserved and internal label is added + assert mock_kwargs["job_config"].labels == expected_labels + + def test_execute_sql_max_rows_config(): """Test execute_sql tool respects max_query_result_rows from config.""" project = "my_project" @@ -2014,3 +2161,93 @@ def test_tool_call_doesnt_change_global_settings(tool_call): # Test settings write mode after assert settings.write_mode == WriteMode.ALLOWED + + +@pytest.mark.parametrize( + ("tool_call",), + [ + pytest.param( + lambda settings, tool_context: query_tool.execute_sql( + project_id="test-project", + query="SELECT * FROM `test-dataset.test-table`", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="execute-sql", + ), + pytest.param( + lambda settings, tool_context: query_tool.forecast( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + timestamp_col="ts_col", + data_col="data_col", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="forecast", + ), + pytest.param( + lambda settings, tool_context: query_tool.analyze_contribution( + project_id="test-project", + input_data="test-dataset.test-table", + dimension_id_cols=["dim1", "dim2"], + contribution_metric="SUM(metric)", + is_test_col="is_test", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="analyze-contribution", + ), + pytest.param( + lambda settings, tool_context: query_tool.detect_anomalies( + project_id="test-project", + history_data="SELECT * FROM `test-dataset.test-table`", + times_series_timestamp_col="ts_timestamp", + times_series_data_col="ts_data", + credentials=mock.create_autospec(Credentials, instance=True), + settings=settings, + tool_context=tool_context, + ), + id="detect-anomalies", + ), + ], +) +def test_tool_call_doesnt_mutate_job_labels(tool_call): + """Test query tools don't mutate job_labels in global settings.""" + original_labels = {"environment": "test", "team": "data"} + settings = BigQueryToolConfig( + write_mode=WriteMode.ALLOWED, + job_labels=original_labels.copy(), + ) + tool_context = mock.create_autospec(ToolContext, instance=True) + tool_context.state.get.return_value = ( + "test-bq-session-id", + "_anonymous_dataset", + ) + + with mock.patch("google.cloud.bigquery.Client", autospec=False) as Client: + # The mock instance + bq_client = Client.return_value + + # Simulate the result of query API + query_job = mock.create_autospec(bigquery.QueryJob) + query_job.destination.dataset_id = "_anonymous_dataset" + bq_client.query.return_value = query_job + bq_client.query_and_wait.return_value = [] + + # Test job_labels before + assert settings.job_labels == original_labels + assert "adk-bigquery-tool" not in settings.job_labels + + # Call the tool + result = tool_call(settings, tool_context) + + # Test successful execution of the tool + assert result == {"status": "SUCCESS", "rows": []} + + # Test job_labels remain unchanged after tool call + assert settings.job_labels == original_labels + assert 
"adk-bigquery-tool" not in settings.job_labels diff --git a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py index 5854c97797..072ccea7d0 100644 --- a/tests/unittests/tools/bigquery/test_bigquery_tool_config.py +++ b/tests/unittests/tools/bigquery/test_bigquery_tool_config.py @@ -77,3 +77,61 @@ def test_bigquery_tool_config_invalid_maximum_bytes_billed(): ), ): BigQueryToolConfig(maximum_bytes_billed=10_485_759) + + +@pytest.mark.parametrize( + "labels", + [ + pytest.param( + {"environment": "test", "team": "data"}, + id="valid-labels", + ), + pytest.param( + {}, + id="empty-labels", + ), + pytest.param( + None, + id="none-labels", + ), + ], +) +def test_bigquery_tool_config_valid_labels(labels): + """Test BigQueryToolConfig accepts valid labels.""" + with pytest.warns(UserWarning): + config = BigQueryToolConfig(job_labels=labels) + assert config.job_labels == labels + + +@pytest.mark.parametrize( + ("labels", "message"), + [ + pytest.param( + "invalid", + "Input should be a valid dictionary", + id="invalid-type", + ), + pytest.param( + {123: "value"}, + "Input should be a valid string", + id="non-str-key", + ), + pytest.param( + {"key": 123}, + "Input should be a valid string", + id="non-str-value", + ), + pytest.param( + {"": "value"}, + "Label keys cannot be empty", + id="empty-label-key", + ), + ], +) +def test_bigquery_tool_config_invalid_labels(labels, message): + """Test BigQueryToolConfig raises an exception with invalid labels.""" + with pytest.raises( + ValueError, + match=message, + ): + BigQueryToolConfig(job_labels=labels) From b638a48357c9fd5089b4d59dcd3f2102d1d2b3b8 Mon Sep 17 00:00:00 2001 From: Kathy Wu Date: Tue, 2 Dec 2025 20:16:58 -0800 Subject: [PATCH 57/63] fix: Update API Registry Toolset to prod cloudapiregistry URL now that it is available Co-authored-by: Kathy Wu PiperOrigin-RevId: 839547174 --- src/google/adk/tools/api_registry.py | 3 +-- tests/unittests/tools/test_api_registry.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/google/adk/tools/api_registry.py b/src/google/adk/tools/api_registry.py index 941c6f0d5c..e3f0076404 100644 --- a/src/google/adk/tools/api_registry.py +++ b/src/google/adk/tools/api_registry.py @@ -29,8 +29,7 @@ from .mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams from .mcp_tool.mcp_toolset import McpToolset -# TODO(wukathy): Update to prod URL once it is available. 
-API_REGISTRY_URL = "https://staging-cloudapiregistry.sandbox.googleapis.com" +API_REGISTRY_URL = "https://cloudapiregistry.googleapis.com" class ApiRegistry: diff --git a/tests/unittests/tools/test_api_registry.py b/tests/unittests/tools/test_api_registry.py index d1131eed0b..df54786049 100644 --- a/tests/unittests/tools/test_api_registry.py +++ b/tests/unittests/tools/test_api_registry.py @@ -73,7 +73,7 @@ def test_init_success(self, MockHttpClient): self.assertIn("test-mcp-server-2", api_registry._mcp_servers) self.assertIn("test-mcp-server-no-url", api_registry._mcp_servers) mock_client_instance.get.assert_called_once_with( - f"https://staging-cloudapiregistry.sandbox.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", + f"https://cloudapiregistry.googleapis.com/v1beta/projects/{self.project_id}/locations/{self.location}/mcpServers", headers={ "Authorization": "Bearer mock_token", "Content-Type": "application/json", From 76dc169a83065ee10283d54253164e04f712ea5f Mon Sep 17 00:00:00 2001 From: Rohit Yanamadala Date: Wed, 3 Dec 2025 00:14:10 -0800 Subject: [PATCH 58/63] fix: Add editLimit parameter to GraphQL query Merge https://github.com/google/adk-python/pull/3771 Co-authored-by: Xuan Yang COPYBARA_INTEGRATE_REVIEW=https://github.com/google/adk-python/pull/3771 from google:ryanaiagent-patch-1 a169e728af223594febc39299e046f2a195d606a PiperOrigin-RevId: 839620767 --- contributing/samples/adk_stale_agent/agent.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/contributing/samples/adk_stale_agent/agent.py b/contributing/samples/adk_stale_agent/agent.py index 5235e0352f..8769adc193 100644 --- a/contributing/samples/adk_stale_agent/agent.py +++ b/contributing/samples/adk_stale_agent/agent.py @@ -135,13 +135,13 @@ def _fetch_graphql_data(item_number: int) -> Dict[str, Any]: RequestException: If the GraphQL query returns errors or the issue is not found. """ query = """ - query($owner: String!, $name: String!, $number: Int!, $commentLimit: Int!, $timelineLimit: Int!) { + query($owner: String!, $name: String!, $number: Int!, $commentLimit: Int!, $timelineLimit: Int!, $editLimit: Int!) { repository(owner: $owner, name: $name) { issue(number: $number) { author { login } createdAt labels(first: 20) { nodes { name } } - + comments(last: $commentLimit) { nodes { author { login } @@ -150,14 +150,14 @@ def _fetch_graphql_data(item_number: int) -> Dict[str, Any]: lastEditedAt } } - + userContentEdits(last: $editLimit) { nodes { editor { login } editedAt } } - + timelineItems(itemTypes: [LABELED_EVENT, RENAMED_TITLE_EVENT, REOPENED_EVENT], last: $timelineLimit) { nodes { __typename From e02b9fb608f645cdcf8300ffe4e64627df10ba28 Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 3 Dec 2025 09:46:42 -0800 Subject: [PATCH 59/63] fix: Add a warning when deploying with the ADK Web UI enabled The warning message shows that ADK Web is for development purposes only and should not be used in production, as it has access to all data. 
This warning is displayed when the `--with-ui` flag is used with `adk deploy` and `adk deploy to-gke` Co-authored-by: George Weale PiperOrigin-RevId: 839795361 --- src/google/adk/cli/cli_tools_click.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index 6c3e7b98a9..c4446278b4 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -110,6 +110,18 @@ def parse_args(self, ctx, args): logger = logging.getLogger("google_adk." + __name__) +_ADK_WEB_WARNING = ( + "ADK Web is for development purposes. It has access to all data and" + " should not be used in production." +) + + +def _warn_if_with_ui(with_ui: bool) -> None: + """Warn when deploying with the developer UI enabled.""" + if with_ui: + click.secho(f"WARNING: {_ADK_WEB_WARNING}", fg="yellow", err=True) + + @click.group(context_settings={"max_content_width": 240}) @click.version_option(version.__version__) def main(): @@ -1429,6 +1441,8 @@ def cli_deploy_cloud_run( err=True, ) + _warn_if_with_ui(with_ui) + session_service_uri = session_service_uri or session_db_url artifact_service_uri = artifact_service_uri or artifact_storage_uri @@ -1848,6 +1862,7 @@ def cli_deploy_gke( --cluster_name=[cluster_name] path/to/my_agent """ try: + _warn_if_with_ui(with_ui) cli_deploy.to_gke( agent_folder=agent, project=project, From 8c9105bf14f57606a73753654922fe26f584dff6 Mon Sep 17 00:00:00 2001 From: George Weale Date: Wed, 3 Dec 2025 09:55:52 -0800 Subject: [PATCH 60/63] chore: Drop Python 3.9 support, set minimum to Python 3.10 Co-authored-by: George Weale PiperOrigin-RevId: 839799108 --- .github/workflows/python-unit-tests.yml | 16 ++--- .../sample-output/alembic.ini | 2 +- contributing/samples/telemetry/main.py | 37 +++++----- llms-full.txt | 6 +- .../adk/a2a/converters/part_converter.py | 16 +---- .../adk/a2a/converters/request_converter.py | 15 +--- .../adk/a2a/executor/a2a_agent_executor.py | 36 ++++------ .../adk/a2a/utils/agent_card_builder.py | 20 ++---- src/google/adk/a2a/utils/agent_to_a2a.py | 23 ++---- src/google/adk/agents/__init__.py | 18 +---- .../adk/agents/mcp_instruction_provider.py | 18 +---- src/google/adk/agents/parallel_agent.py | 70 ++++++++---------- src/google/adk/agents/remote_a2a_agent.py | 43 +++++------ src/google/adk/cli/fast_api.py | 27 +++---- src/google/adk/tools/crewai_tool.py | 13 +--- src/google/adk/tools/mcp_tool/__init__.py | 12 +--- .../adk/tools/mcp_tool/mcp_session_manager.py | 21 ++---- src/google/adk/tools/mcp_tool/mcp_tool.py | 27 ++----- src/google/adk/tools/mcp_tool/mcp_toolset.py | 19 +---- src/google/adk/utils/context_utils.py | 29 +------- .../a2a/converters/test_event_converter.py | 61 ++++++---------- .../a2a/converters/test_part_converter.py | 37 +++------- .../a2a/converters/test_request_converter.py | 27 ++----- tests/unittests/a2a/converters/test_utils.py | 28 ++------ .../a2a/executor/test_a2a_agent_executor.py | 41 ++++------- .../executor/test_task_result_aggregator.py | 33 +++------ .../a2a/utils/test_agent_card_builder.py | 71 +++++++----------- .../unittests/a2a/utils/test_agent_to_a2a.py | 45 ++++-------- .../agents/test_mcp_instruction_provider.py | 22 +----- .../unittests/agents/test_remote_a2a_agent.py | 72 ++++++------------- tests/unittests/cli/test_fast_api.py | 11 --- .../evaluation/test_local_eval_service.py | 3 - .../plugins/test_reflect_retry_tool_plugin.py | 6 +- tests/unittests/telemetry/test_functional.py | 2 +- 
.../computer_use/test_computer_use_tool.py | 2 +- .../mcp_tool/test_mcp_session_manager.py | 44 ++---------- .../unittests/tools/mcp_tool/test_mcp_tool.py | 39 ++-------- .../tools/mcp_tool/test_mcp_toolset.py | 46 +++--------- .../tools/retrieval/test_files_retrieval.py | 4 -- tests/unittests/tools/test_mcp_toolset.py | 21 +----- 40 files changed, 279 insertions(+), 804 deletions(-) diff --git a/.github/workflows/python-unit-tests.yml b/.github/workflows/python-unit-tests.yml index 8f8f46e953..3fc6bd943f 100644 --- a/.github/workflows/python-unit-tests.yml +++ b/.github/workflows/python-unit-tests.yml @@ -25,7 +25,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13"] steps: - name: Checkout code @@ -48,14 +48,6 @@ jobs: - name: Run unit tests with pytest run: | source .venv/bin/activate - if [[ "${{ matrix.python-version }}" == "3.9" ]]; then - pytest tests/unittests \ - --ignore=tests/unittests/a2a \ - --ignore=tests/unittests/tools/mcp_tool \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py - else - pytest tests/unittests \ - --ignore=tests/unittests/artifacts/test_artifact_service.py \ - --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py - fi \ No newline at end of file + pytest tests/unittests \ + --ignore=tests/unittests/artifacts/test_artifact_service.py \ + --ignore=tests/unittests/tools/google_api_tool/test_googleapi_to_openapi_converter.py \ No newline at end of file diff --git a/contributing/samples/migrate_session_db/sample-output/alembic.ini b/contributing/samples/migrate_session_db/sample-output/alembic.ini index 6405320948..e346ee8ac6 100644 --- a/contributing/samples/migrate_session_db/sample-output/alembic.ini +++ b/contributing/samples/migrate_session_db/sample-output/alembic.ini @@ -21,7 +21,7 @@ prepend_sys_path = . # timezone to use when rendering the date within the migration file # as well as the filename. -# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library. +# If specified, requires the python>=3.10 and tzdata library. # Any required deps can installed by adding `alembic[tz]` to the pip requirements # string value is passed to ZoneInfo() # leave blank for localtime diff --git a/contributing/samples/telemetry/main.py b/contributing/samples/telemetry/main.py index e580060dc4..c6e05f0f62 100755 --- a/contributing/samples/telemetry/main.py +++ b/contributing/samples/telemetry/main.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +from contextlib import aclosing import os import time @@ -46,19 +47,16 @@ async def run_prompt(session: Session, new_message: str): role='user', parts=[types.Part.from_text(text=new_message)] ) print('** User says:', content.model_dump(exclude_none=True)) - # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is - # no longer supported. 
- agen = runner.run_async( - user_id=user_id_1, - session_id=session.id, - new_message=content, - ) - try: + async with aclosing( + runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ) + ) as agen: async for event in agen: if event.content.parts and event.content.parts[0].text: print(f'** {event.author}: {event.content.parts[0].text}') - finally: - await agen.aclose() async def run_prompt_bytes(session: Session, new_message: str): content = types.Content( @@ -70,20 +68,17 @@ async def run_prompt_bytes(session: Session, new_message: str): ], ) print('** User says:', content.model_dump(exclude_none=True)) - # TODO - migrate try...finally to contextlib.aclosing after Python 3.9 is - # no longer supported. - agen = runner.run_async( - user_id=user_id_1, - session_id=session.id, - new_message=content, - run_config=RunConfig(save_input_blobs_as_artifacts=True), - ) - try: + async with aclosing( + runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + run_config=RunConfig(save_input_blobs_as_artifacts=True), + ) + ) as agen: async for event in agen: if event.content.parts and event.content.parts[0].text: print(f'** {event.author}: {event.content.parts[0].text}') - finally: - await agen.aclose() start_time = time.time() print('Start time:', start_time) diff --git a/llms-full.txt b/llms-full.txt index 4c744512e4..b84e9496ee 100644 --- a/llms-full.txt +++ b/llms-full.txt @@ -5620,7 +5620,7 @@ pip install google-cloud-aiplatform[adk,agent_engines] ``` !!!info - Agent Engine only supported Python version >=3.9 and <=3.12. + Agent Engine only supported Python version >=3.10 and <=3.12. ### Initialization @@ -8073,7 +8073,7 @@ setting up a basic agent with multiple tools, and running it locally either in t This quickstart assumes a local IDE (VS Code, PyCharm, IntelliJ IDEA, etc.) -with Python 3.9+ or Java 17+ and terminal access. This method runs the +with Python 3.10+ or Java 17+ and terminal access. This method runs the application entirely on your machine and is recommended for internal development. ## 1. Set up Environment & Install ADK {#venv-install} @@ -16475,7 +16475,7 @@ This guide covers two primary integration patterns: Before you begin, ensure you have the following set up: * **Set up ADK:** Follow the standard ADK [setup instructions](../get-started/quickstart.md/#venv-install) in the quickstart. -* **Install/update Python/Java:** MCP requires Python version of 3.9 or higher for Python or Java 17+. +* **Install/update Python/Java:** MCP requires Python version of 3.10 or higher for Python or Java 17+. * **Setup Node.js and npx:** **(Python only)** Many community MCP servers are distributed as Node.js packages and run using `npx`. Install Node.js (which includes npx) if you haven't already. For details, see [https://nodejs.org/en](https://nodejs.org/en). * **Verify Installations:** **(Python only)** Confirm `adk` and `npx` are in your PATH within the activated virtual environment: diff --git a/src/google/adk/a2a/converters/part_converter.py b/src/google/adk/a2a/converters/part_converter.py index a21042cc10..dfe6f4a0a2 100644 --- a/src/google/adk/a2a/converters/part_converter.py +++ b/src/google/adk/a2a/converters/part_converter.py @@ -26,23 +26,11 @@ from typing import Optional from typing import Union -from .utils import _get_adk_metadata_key - -try: - from a2a import types as a2a_types -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. 
Please upgrade your Python version.' - ) from e - else: - raise e - +from a2a import types as a2a_types from google.genai import types as genai_types from ..experimental import a2a_experimental +from .utils import _get_adk_metadata_key logger = logging.getLogger('google_adk.' + __name__) diff --git a/src/google/adk/a2a/converters/request_converter.py b/src/google/adk/a2a/converters/request_converter.py index 39db41dac6..1746ec0bca 100644 --- a/src/google/adk/a2a/converters/request_converter.py +++ b/src/google/adk/a2a/converters/request_converter.py @@ -15,23 +15,12 @@ from __future__ import annotations from collections.abc import Callable -import sys from typing import Any from typing import Optional -from pydantic import BaseModel - -try: - from a2a.server.agent_execution import RequestContext -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' - ) from e - else: - raise e - +from a2a.server.agent_execution import RequestContext from google.genai import types as genai_types +from pydantic import BaseModel from ...runners import RunConfig from ..experimental import a2a_experimental diff --git a/src/google/adk/a2a/executor/a2a_agent_executor.py b/src/google/adk/a2a/executor/a2a_agent_executor.py index 608a818864..b6880aaa5c 100644 --- a/src/google/adk/a2a/executor/a2a_agent_executor.py +++ b/src/google/adk/a2a/executor/a2a_agent_executor.py @@ -23,34 +23,22 @@ from typing import Optional import uuid -from ...utils.context_utils import Aclosing - -try: - from a2a.server.agent_execution import AgentExecutor - from a2a.server.agent_execution.context import RequestContext - from a2a.server.events.event_queue import EventQueue - from a2a.types import Artifact - from a2a.types import Message - from a2a.types import Role - from a2a.types import TaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' 
- ) from e - else: - raise e +from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Artifact +from a2a.types import Message +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart from google.adk.runners import Runner from pydantic import BaseModel from typing_extensions import override +from ...utils.context_utils import Aclosing from ..converters.event_converter import AdkEventToA2AEventsConverter from ..converters.event_converter import convert_event_to_a2a_events from ..converters.part_converter import A2APartToGenAIPartConverter diff --git a/src/google/adk/a2a/utils/agent_card_builder.py b/src/google/adk/a2a/utils/agent_card_builder.py index aa7f657f99..c007870931 100644 --- a/src/google/adk/a2a/utils/agent_card_builder.py +++ b/src/google/adk/a2a/utils/agent_card_builder.py @@ -15,25 +15,15 @@ from __future__ import annotations import re -import sys from typing import Dict from typing import List from typing import Optional -try: - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentProvider - from a2a.types import AgentSkill - from a2a.types import SecurityScheme -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - 'A2A requires Python 3.10 or above. Please upgrade your Python version.' - ) from e - else: - raise e - +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentProvider +from a2a.types import AgentSkill +from a2a.types import SecurityScheme from ...agents.base_agent import BaseAgent from ...agents.llm_agent import LlmAgent diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 72a2292fb3..1a1ba35618 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -15,30 +15,18 @@ from __future__ import annotations import logging -import sys - -try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python version." 
- ) from e - else: - raise e - from typing import Optional from typing import Union +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCard from starlette.applications import Starlette from ...agents.base_agent import BaseAgent from ...artifacts.in_memory_artifact_service import InMemoryArtifactService from ...auth.credential_service.in_memory_credential_service import InMemoryCredentialService -from ...cli.utils.logs import setup_adk_logger from ...memory.in_memory_memory_service import InMemoryMemoryService from ...runners import Runner from ...sessions.in_memory_session_service import InMemorySessionService @@ -117,7 +105,8 @@ def to_a2a( app = to_a2a(agent, agent_card=my_custom_agent_card) """ # Set up ADK logging to ensure logs are visible when using uvicorn directly - setup_adk_logger(logging.INFO) + adk_logger = logging.getLogger("google_adk") + adk_logger.setLevel(logging.INFO) async def create_runner() -> Runner: """Create a runner for the agent.""" diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index 5710a21b7f..b5f8e88cde 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -12,9 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging -import sys - from .base_agent import BaseAgent from .invocation_context import InvocationContext from .live_request_queue import LiveRequest @@ -22,6 +19,7 @@ from .llm_agent import Agent from .llm_agent import LlmAgent from .loop_agent import LoopAgent +from .mcp_instruction_provider import McpInstructionProvider from .parallel_agent import ParallelAgent from .run_config import RunConfig from .sequential_agent import SequentialAgent @@ -31,6 +29,7 @@ 'BaseAgent', 'LlmAgent', 'LoopAgent', + 'McpInstructionProvider', 'ParallelAgent', 'SequentialAgent', 'InvocationContext', @@ -38,16 +37,3 @@ 'LiveRequestQueue', 'RunConfig', ] - -if sys.version_info < (3, 10): - logger = logging.getLogger('google_adk.' + __name__) - logger.warning( - 'MCP requires Python 3.10 or above. Please upgrade your Python' - ' version in order to use it.' - ) -else: - from .mcp_instruction_provider import McpInstructionProvider - - __all__.extend([ - 'McpInstructionProvider', - ]) diff --git a/src/google/adk/agents/mcp_instruction_provider.py b/src/google/adk/agents/mcp_instruction_provider.py index e9f40663c9..20896a7a04 100644 --- a/src/google/adk/agents/mcp_instruction_provider.py +++ b/src/google/adk/agents/mcp_instruction_provider.py @@ -22,24 +22,12 @@ from typing import Dict from typing import TextIO +from mcp import types + +from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager from .llm_agent import InstructionProvider from .readonly_context import ReadonlyContext -# Attempt to import MCP Session Manager from the MCP library, and hints user to -# upgrade their Python version to 3.10 if it fails. -try: - from mcp import types - - from ..tools.mcp_tool.mcp_session_manager import MCPSessionManager -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "MCP Session Manager requires Python 3.10 or above. Please upgrade" - " your Python version." 
- ) from e - else: - raise e - class McpInstructionProvider(InstructionProvider): """Fetches agent instructions from an MCP server.""" diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index f7270a75c9..09e65a67a4 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -48,34 +48,13 @@ def _create_branch_ctx_for_sub_agent( return invocation_context -# TODO - remove once Python <3.11 is no longer supported. -async def _merge_agent_run_pre_3_11( +async def _merge_agent_run( agent_runs: list[AsyncGenerator[Event, None]], ) -> AsyncGenerator[Event, None]: - """Merges the agent run event generator. - This version works in Python 3.9 and 3.10 and uses custom replacement for - asyncio.TaskGroup for tasks cancellation and exception handling. - - This implementation guarantees for each agent, it won't move on until the - generated event is processed by upstream runner. - - Args: - agent_runs: A list of async generators that yield events from each agent. - - Yields: - Event: The next event from the merged generator. - """ + """Merges agent runs using asyncio.TaskGroup on Python 3.11+.""" sentinel = object() queue = asyncio.Queue() - def propagate_exceptions(tasks): - # Propagate exceptions and errors from tasks. - for task in tasks: - if task.done(): - # Ignore the result (None) of correctly finished tasks and re-raise - # exceptions and errors. - task.result() - # Agents are processed in parallel. # Events for each agent are put on queue sequentially. async def process_an_agent(events_for_one_agent): @@ -89,39 +68,34 @@ async def process_an_agent(events_for_one_agent): # Mark agent as finished. await queue.put((sentinel, None)) - tasks = [] - try: + async with asyncio.TaskGroup() as tg: for events_for_one_agent in agent_runs: - tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent))) + tg.create_task(process_an_agent(events_for_one_agent)) sentinel_count = 0 # Run until all agents finished processing. while sentinel_count < len(agent_runs): - propagate_exceptions(tasks) event, resume_signal = await queue.get() # Agent finished processing. if event is sentinel: sentinel_count += 1 else: yield event - # Signal to agent that event has been processed by runner and it can - # continue now. + # Signal to agent that it should generate next event. resume_signal.set() - finally: - for task in tasks: - task.cancel() -async def _merge_agent_run( +# TODO - remove once Python <3.11 is no longer supported. +async def _merge_agent_run_pre_3_11( agent_runs: list[AsyncGenerator[Event, None]], ) -> AsyncGenerator[Event, None]: - """Merges the agent run event generator. + """Merges agent runs for Python 3.10 without asyncio.TaskGroup. - This implementation guarantees for each agent, it won't move on until the - generated event is processed by upstream runner. + Uses custom cancellation and exception handling to mirror TaskGroup + semantics. Each agent waits until the runner processes emitted events. Args: - agent_runs: A list of async generators that yield events from each agent. + agent_runs: Async generators that yield events from each agent. Yields: Event: The next event from the merged generator. @@ -129,6 +103,14 @@ async def _merge_agent_run( sentinel = object() queue = asyncio.Queue() + def propagate_exceptions(tasks): + # Propagate exceptions and errors from tasks. + for task in tasks: + if task.done(): + # Ignore the result (None) of correctly finished tasks and re-raise + # exceptions and errors. 
+ task.result() + # Agents are processed in parallel. # Events for each agent are put on queue sequentially. async def process_an_agent(events_for_one_agent): @@ -142,21 +124,27 @@ async def process_an_agent(events_for_one_agent): # Mark agent as finished. await queue.put((sentinel, None)) - async with asyncio.TaskGroup() as tg: + tasks = [] + try: for events_for_one_agent in agent_runs: - tg.create_task(process_an_agent(events_for_one_agent)) + tasks.append(asyncio.create_task(process_an_agent(events_for_one_agent))) sentinel_count = 0 # Run until all agents finished processing. while sentinel_count < len(agent_runs): + propagate_exceptions(tasks) event, resume_signal = await queue.get() # Agent finished processing. if event is sentinel: sentinel_count += 1 else: yield event - # Signal to agent that it should generate next event. + # Signal to agent that event has been processed by runner and it can + # continue now. resume_signal.set() + finally: + for task in tasks: + task.cancel() class ParallelAgent(BaseAgent): @@ -195,13 +183,11 @@ async def _run_async_impl( pause_invocation = False try: - # TODO remove if once Python <3.11 is no longer supported. merge_func = ( _merge_agent_run if sys.version_info >= (3, 11) else _merge_agent_run_pre_3_11 ) - async with Aclosing(merge_func(agent_runs)) as agen: async for event in agen: yield event diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index 5d42730937..8d133060ec 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -26,30 +26,22 @@ from urllib.parse import urlparse import uuid -try: - from a2a.client import Client as A2AClient - from a2a.client import ClientEvent as A2AClientEvent - from a2a.client.card_resolver import A2ACardResolver - from a2a.client.client import ClientConfig as A2AClientConfig - from a2a.client.client_factory import ClientFactory as A2AClientFactory - from a2a.client.errors import A2AClientError - from a2a.types import AgentCard - from a2a.types import Message as A2AMessage - from a2a.types import Part as A2APart - from a2a.types import Role - from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent - from a2a.types import TransportProtocol as A2ATransport -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python version." - ) from e - else: - raise e +from a2a.client import Client as A2AClient +from a2a.client import ClientEvent as A2AClientEvent +from a2a.client.card_resolver import A2ACardResolver +from a2a.client.client import ClientConfig as A2AClientConfig +from a2a.client.client_factory import ClientFactory as A2AClientFactory +from a2a.client.errors import A2AClientError +from a2a.types import AgentCard +from a2a.types import Message as A2AMessage +from a2a.types import Part as A2APart +from a2a.types import Role +from a2a.types import TaskArtifactUpdateEvent as A2ATaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent as A2ATaskStatusUpdateEvent +from a2a.types import TransportProtocol as A2ATransport +from google.genai import types as genai_types +import httpx try: from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH @@ -57,9 +49,6 @@ # Fallback for older versions of a2a-sdk. 
AGENT_CARD_WELL_KNOWN_PATH = "/.well-known/agent.json" -from google.genai import types as genai_types -import httpx - from ..a2a.converters.event_converter import convert_a2a_message_to_event from ..a2a.converters.event_converter import convert_a2a_task_to_event from ..a2a.converters.event_converter import convert_event_to_a2a_message diff --git a/src/google/adk/cli/fast_api.py b/src/google/adk/cli/fast_api.py index f9170968fd..df06b1cf4c 100644 --- a/src/google/adk/cli/fast_api.py +++ b/src/google/adk/cli/fast_api.py @@ -342,25 +342,14 @@ async def get_agent_builder( ) if a2a: - try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard - from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH - - from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor - - except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "A2A requires Python 3.10 or above. Please upgrade your Python" - " version." - ) from e - else: - raise e + from a2a.server.apps import A2AStarletteApplication + from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.tasks import InMemoryTaskStore + from a2a.types import AgentCard + from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH + + from ..a2a.executor.a2a_agent_executor import A2aAgentExecutor + # locate all a2a agent apps in the agents directory base_path = Path.cwd() / agents_dir # the root agents directory should be an existing folder diff --git a/src/google/adk/tools/crewai_tool.py b/src/google/adk/tools/crewai_tool.py index eaef479274..875b82e5b9 100644 --- a/src/google/adk/tools/crewai_tool.py +++ b/src/google/adk/tools/crewai_tool.py @@ -30,16 +30,9 @@ try: from crewai.tools import BaseTool as CrewaiBaseTool except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - 'Crewai Tools require Python 3.10+. Please upgrade your Python version.' - ) from e - else: - raise ImportError( - "Crewai Tools require pip install 'google-adk[extensions]'." - ) from e + raise ImportError( + "Crewai Tools require pip install 'google-adk[extensions]'." + ) from e class CrewaiTool(FunctionTool): diff --git a/src/google/adk/tools/mcp_tool/__init__.py b/src/google/adk/tools/mcp_tool/__init__.py index f1e56b99c4..1170b2e1af 100644 --- a/src/google/adk/tools/mcp_tool/__init__.py +++ b/src/google/adk/tools/mcp_tool/__init__.py @@ -39,15 +39,7 @@ except ImportError as e: import logging - import sys logger = logging.getLogger('google_adk.' + __name__) - - if sys.version_info < (3, 10): - logger.warning( - 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' - ' version.' 
- ) - else: - logger.debug('MCP Tool is not installed') - logger.debug(e) + logger.debug('MCP Tool is not installed') + logger.debug(e) diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index 7d9714aada..c9c4c2ae66 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -29,24 +29,13 @@ from typing import Union import anyio +from mcp import ClientSession +from mcp import StdioServerParameters +from mcp.client.sse import sse_client +from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client from pydantic import BaseModel -try: - from mcp import ClientSession - from mcp import StdioServerParameters - from mcp.client.sse import sse_client - from mcp.client.stdio import stdio_client - from mcp.client.streamable_http import streamablehttp_client -except ImportError as e: - - if sys.version_info < (3, 10): - raise ImportError( - 'MCP Tool requires Python 3.10 or above. Please upgrade your Python' - ' version.' - ) from e - else: - raise e - logger = logging.getLogger('google_adk.' + __name__) diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index 284aea4105..b15f2c73fe 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -17,7 +17,6 @@ import base64 import inspect import logging -import sys from typing import Any from typing import Callable from typing import Dict @@ -27,35 +26,21 @@ from fastapi.openapi.models import APIKeyIn from google.genai.types import FunctionDeclaration +from mcp.types import Tool as McpBaseTool from typing_extensions import override from ...agents.readonly_context import ReadonlyContext -from ...features import FeatureName -from ...features import is_feature_enabled -from .._gemini_schema_util import _to_gemini_schema -from .mcp_session_manager import MCPSessionManager -from .mcp_session_manager import retry_on_errors - -# Attempt to import MCP Tool from the MCP library, and hints user to upgrade -# their Python version to 3.10 if it fails. -try: - from mcp.types import Tool as McpBaseTool -except ImportError as e: - if sys.version_info < (3, 10): - raise ImportError( - "MCP Tool requires Python 3.10 or above. Please upgrade your Python" - " version." - ) from e - else: - raise e - - from ...auth.auth_credential import AuthCredential from ...auth.auth_schemes import AuthScheme from ...auth.auth_tool import AuthConfig +from ...features import FeatureName +from ...features import is_feature_enabled +from .._gemini_schema_util import _to_gemini_schema from ..base_authenticated_tool import BaseAuthenticatedTool # import from ..tool_context import ToolContext +from .mcp_session_manager import MCPSessionManager +from .mcp_session_manager import retry_on_errors logger = logging.getLogger("google_adk." 
+ __name__) diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index 3768477e1d..035b75878b 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -25,6 +25,8 @@ from typing import Union import warnings +from mcp import StdioServerParameters +from mcp.types import ListToolsResult from pydantic import model_validator from typing_extensions import override @@ -41,23 +43,6 @@ from .mcp_session_manager import SseConnectionParams from .mcp_session_manager import StdioConnectionParams from .mcp_session_manager import StreamableHTTPConnectionParams - -# Attempt to import MCP Tool from the MCP library, and hints user to upgrade -# their Python version to 3.10 if it fails. -try: - from mcp import StdioServerParameters - from mcp.types import ListToolsResult -except ImportError as e: - import sys - - if sys.version_info < (3, 10): - raise ImportError( - "MCP Tool requires Python 3.10 or above. Please upgrade your Python" - " version." - ) from e - else: - raise e - from .mcp_tool import MCPTool logger = logging.getLogger("google_adk." + __name__) diff --git a/src/google/adk/utils/context_utils.py b/src/google/adk/utils/context_utils.py index bd8dacb9d8..a75feae3dd 100644 --- a/src/google/adk/utils/context_utils.py +++ b/src/google/adk/utils/context_utils.py @@ -20,30 +20,7 @@ from __future__ import annotations -from contextlib import AbstractAsyncContextManager -from typing import Any -from typing import AsyncGenerator +from contextlib import aclosing - -class Aclosing(AbstractAsyncContextManager): - """Async context manager for safely finalizing an asynchronously cleaned-up - resource such as an async generator, calling its ``aclose()`` method. - Needed to correctly close contexts for OTel spans. - See https://github.com/google/adk-python/issues/1670#issuecomment-3115891100. - - Based on - https://docs.python.org/3/library/contextlib.html#contextlib.aclosing - which is available in Python 3.10+. - - TODO: replace all occurrences with contextlib.aclosing once Python 3.9 is no - longer supported. - """ - - def __init__(self, async_generator: AsyncGenerator[Any, None]): - self.async_generator = async_generator - - async def __aenter__(self): - return self.async_generator - - async def __aexit__(self, *exc_info): - await self.async_generator.aclose() +# Re-export aclosing for backward compatibility +Aclosing = aclosing diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index cb3f7a6858..49b7d3c2b6 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -12,50 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.types import DataPart +from a2a.types import Message +from a2a.types import Role +from a2a.types import Task +from a2a.types import TaskState +from a2a.types import TaskStatusUpdateEvent +from google.adk.a2a.converters.event_converter import _create_artifact_id +from google.adk.a2a.converters.event_converter import _create_error_status_event +from google.adk.a2a.converters.event_converter import _create_status_update_event +from google.adk.a2a.converters.event_converter import _get_adk_metadata_key +from google.adk.a2a.converters.event_converter import _get_context_metadata +from google.adk.a2a.converters.event_converter import _process_long_running_tool +from google.adk.a2a.converters.event_converter import _serialize_metadata_value +from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR +from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event +from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events +from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message +from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX +from google.adk.agents.invocation_context import InvocationContext +from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import DataPart - from a2a.types import Message - from a2a.types import Role - from a2a.types import Task - from a2a.types import TaskState - from a2a.types import TaskStatusUpdateEvent - from google.adk.a2a.converters.event_converter import _create_artifact_id - from google.adk.a2a.converters.event_converter import _create_error_status_event - from google.adk.a2a.converters.event_converter import _create_status_update_event - from google.adk.a2a.converters.event_converter import _get_adk_metadata_key - from google.adk.a2a.converters.event_converter import _get_context_metadata - from google.adk.a2a.converters.event_converter import _process_long_running_tool - from google.adk.a2a.converters.event_converter import _serialize_metadata_value - from google.adk.a2a.converters.event_converter import ARTIFACT_ID_SEPARATOR - from google.adk.a2a.converters.event_converter import convert_a2a_task_to_event - from google.adk.a2a.converters.event_converter import convert_event_to_a2a_events - from google.adk.a2a.converters.event_converter import convert_event_to_a2a_message - from google.adk.a2a.converters.event_converter import DEFAULT_ERROR_MESSAGE - from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX - from google.adk.agents.invocation_context import InvocationContext - from google.adk.events.event import Event - from google.adk.events.event_actions import EventActions -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. 
- pass - else: - raise e - class TestEventConverter: """Test suite for event_converter module.""" diff --git a/tests/unittests/a2a/converters/test_part_converter.py b/tests/unittests/a2a/converters/test_part_converter.py index 5a8bad1096..541ab7709d 100644 --- a/tests/unittests/a2a/converters/test_part_converter.py +++ b/tests/unittests/a2a/converters/test_part_converter.py @@ -13,38 +13,21 @@ # limitations under the License. import json -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a import types as a2a_types +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE +from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY +from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part +from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.genai import types as genai_types import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a import types as a2a_types - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE - from google.adk.a2a.converters.part_converter import A2A_DATA_PART_METADATA_TYPE_KEY - from google.adk.a2a.converters.part_converter import convert_a2a_part_to_genai_part - from google.adk.a2a.converters.part_converter import convert_genai_part_to_a2a_part - from google.adk.a2a.converters.utils import _get_adk_metadata_key - from google.genai import types as genai_types -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestConvertA2aPartToGenaiPart: """Test cases for convert_a2a_part_to_genai_part function.""" diff --git a/tests/unittests/a2a/converters/test_request_converter.py b/tests/unittests/a2a/converters/test_request_converter.py index a7c21e4dbc..173b122d7c 100644 --- a/tests/unittests/a2a/converters/test_request_converter.py +++ b/tests/unittests/a2a/converters/test_request_converter.py @@ -12,33 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.server.agent_execution import RequestContext +from google.adk.a2a.converters.request_converter import _get_user_id +from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request +from google.adk.runners import RunConfig +from google.genai import types as genai_types import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.agent_execution import RequestContext - from google.adk.a2a.converters.request_converter import _get_user_id - from google.adk.a2a.converters.request_converter import convert_a2a_request_to_agent_run_request - from google.adk.runners import RunConfig - from google.genai import types as genai_types -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestGetUserId: """Test cases for _get_user_id function.""" diff --git a/tests/unittests/a2a/converters/test_utils.py b/tests/unittests/a2a/converters/test_utils.py index 6c8511161a..0d896852aa 100644 --- a/tests/unittests/a2a/converters/test_utils.py +++ b/tests/unittests/a2a/converters/test_utils.py @@ -12,31 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys - +from google.adk.a2a.converters.utils import _from_a2a_context_id +from google.adk.a2a.converters.utils import _get_adk_metadata_key +from google.adk.a2a.converters.utils import _to_a2a_context_id +from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX +from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.a2a.converters.utils import _from_a2a_context_id - from google.adk.a2a.converters.utils import _get_adk_metadata_key - from google.adk.a2a.converters.utils import _to_a2a_context_id - from google.adk.a2a.converters.utils import ADK_CONTEXT_ID_PREFIX - from google.adk.a2a.converters.utils import ADK_METADATA_KEY_PREFIX -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestUtilsFunctions: """Test suite for utils module functions.""" diff --git a/tests/unittests/a2a/executor/test_a2a_agent_executor.py b/tests/unittests/a2a/executor/test_a2a_agent_executor.py index 4bcc7a91d7..58d7521f7d 100644 --- a/tests/unittests/a2a/executor/test_a2a_agent_executor.py +++ b/tests/unittests/a2a/executor/test_a2a_agent_executor.py @@ -12,41 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import Message +from a2a.types import TaskState +from a2a.types import TextPart +from google.adk.a2a.converters.request_converter import AgentRunRequest +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig +from google.adk.events.event import Event +from google.adk.runners import RunConfig +from google.adk.runners import Runner +from google.genai.types import Content import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.agent_execution.context import RequestContext - from a2a.server.events.event_queue import EventQueue - from a2a.types import Message - from a2a.types import TaskState - from a2a.types import TextPart - from google.adk.a2a.converters.request_converter import AgentRunRequest - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutorConfig - from google.adk.events.event import Event - from google.adk.runners import RunConfig - from google.adk.runners import Runner - from google.genai.types import Content -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestA2aAgentExecutor: """Test suite for A2aAgentExecutor class.""" diff --git a/tests/unittests/a2a/executor/test_task_result_aggregator.py b/tests/unittests/a2a/executor/test_task_result_aggregator.py index 9d03db9dc8..b809b62728 100644 --- a/tests/unittests/a2a/executor/test_task_result_aggregator.py +++ b/tests/unittests/a2a/executor/test_task_result_aggregator.py @@ -12,35 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock +from a2a.types import Message +from a2a.types import Part +from a2a.types import Role +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import Message - from a2a.types import Part - from a2a.types import Role - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - from google.adk.a2a.executor.task_result_aggregator import TaskResultAggregator -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. 
- pass - else: - raise e - def create_test_message(text: str): """Helper function to create a test Message object.""" diff --git a/tests/unittests/a2a/utils/test_agent_card_builder.py b/tests/unittests/a2a/utils/test_agent_card_builder.py index e0b62468e5..3bf3202897 100644 --- a/tests/unittests/a2a/utils/test_agent_card_builder.py +++ b/tests/unittests/a2a/utils/test_agent_card_builder.py @@ -12,55 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import Mock from unittest.mock import patch +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentProvider +from a2a.types import AgentSkill +from a2a.types import SecurityScheme +from google.adk.a2a.utils.agent_card_builder import _build_agent_description +from google.adk.a2a.utils.agent_card_builder import _build_llm_agent_description_with_instructions +from google.adk.a2a.utils.agent_card_builder import _build_loop_description +from google.adk.a2a.utils.agent_card_builder import _build_orchestration_skill +from google.adk.a2a.utils.agent_card_builder import _build_parallel_description +from google.adk.a2a.utils.agent_card_builder import _build_sequential_description +from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples +from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction +from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name +from google.adk.a2a.utils.agent_card_builder import _get_agent_type +from google.adk.a2a.utils.agent_card_builder import _get_default_description +from google.adk.a2a.utils.agent_card_builder import _get_input_modes +from google.adk.a2a.utils.agent_card_builder import _get_output_modes +from google.adk.a2a.utils.agent_card_builder import _get_workflow_description +from google.adk.a2a.utils.agent_card_builder import _replace_pronouns +from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.loop_agent import LoopAgent +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.tools.example_tool import ExampleTool import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentProvider - from a2a.types import AgentSkill - from a2a.types import SecurityScheme - from google.adk.a2a.utils.agent_card_builder import _build_agent_description - from google.adk.a2a.utils.agent_card_builder import _build_llm_agent_description_with_instructions - from google.adk.a2a.utils.agent_card_builder import _build_loop_description - from google.adk.a2a.utils.agent_card_builder import _build_orchestration_skill - from google.adk.a2a.utils.agent_card_builder import _build_parallel_description - from google.adk.a2a.utils.agent_card_builder import _build_sequential_description - from google.adk.a2a.utils.agent_card_builder import _convert_example_tool_examples - from google.adk.a2a.utils.agent_card_builder import _extract_examples_from_instruction - from google.adk.a2a.utils.agent_card_builder import _get_agent_skill_name - from 
google.adk.a2a.utils.agent_card_builder import _get_agent_type - from google.adk.a2a.utils.agent_card_builder import _get_default_description - from google.adk.a2a.utils.agent_card_builder import _get_input_modes - from google.adk.a2a.utils.agent_card_builder import _get_output_modes - from google.adk.a2a.utils.agent_card_builder import _get_workflow_description - from google.adk.a2a.utils.agent_card_builder import _replace_pronouns - from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder - from google.adk.agents.base_agent import BaseAgent - from google.adk.agents.llm_agent import LlmAgent - from google.adk.agents.loop_agent import LoopAgent - from google.adk.agents.parallel_agent import ParallelAgent - from google.adk.agents.sequential_agent import SequentialAgent - from google.adk.tools.example_tool import ExampleTool -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e - class TestAgentCardBuilder: """Test suite for AgentCardBuilder class.""" diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index ee80b0233b..503e572f2f 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -12,42 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch +from a2a.server.apps import A2AStarletteApplication +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import InMemoryTaskStore +from a2a.types import AgentCard +from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor +from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder +from google.adk.a2a.utils.agent_to_a2a import to_a2a +from google.adk.agents.base_agent import BaseAgent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService import pytest - -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from a2a.server.apps import A2AStarletteApplication - from a2a.server.request_handlers import DefaultRequestHandler - from a2a.server.tasks import InMemoryTaskStore - from a2a.types import AgentCard - from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor - from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder - from google.adk.a2a.utils.agent_to_a2a import to_a2a - from google.adk.agents.base_agent import BaseAgent - from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService - from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService - from google.adk.memory.in_memory_memory_service import InMemoryMemoryService - from google.adk.runners import Runner - from 
google.adk.sessions.in_memory_session_service import InMemorySessionService - from starlette.applications import Starlette -except ImportError as e: - if sys.version_info < (3, 10): - # Imports are not needed since tests will be skipped due to pytestmark. - # The imported names are only used within test methods, not at module level, - # so no NameError occurs during module compilation. - pass - else: - raise e +from starlette.applications import Starlette class TestToA2A: diff --git a/tests/unittests/agents/test_mcp_instruction_provider.py b/tests/unittests/agents/test_mcp_instruction_provider.py index 1f2d098c2a..256d812630 100644 --- a/tests/unittests/agents/test_mcp_instruction_provider.py +++ b/tests/unittests/agents/test_mcp_instruction_provider.py @@ -13,34 +13,14 @@ # limitations under the License. """Unit tests for McpInstructionProvider.""" -import sys from unittest.mock import AsyncMock from unittest.mock import MagicMock from unittest.mock import patch +from google.adk.agents.mcp_instruction_provider import McpInstructionProvider from google.adk.agents.readonly_context import ReadonlyContext import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), - reason="MCP instruction provider requires Python 3.10+", -) - -# Import dependencies with version checking -try: - from google.adk.agents.mcp_instruction_provider import McpInstructionProvider -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - McpInstructionProvider = DummyClass - else: - raise e - class TestMcpInstructionProvider: """Unit tests for McpInstructionProvider.""" diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index fd722abf3f..e7865f39ba 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -14,70 +14,38 @@ import json from pathlib import Path -import sys import tempfile from unittest.mock import AsyncMock from unittest.mock import create_autospec from unittest.mock import Mock from unittest.mock import patch +from a2a.client.client import ClientConfig +from a2a.client.client import Consumer +from a2a.client.client_factory import ClientFactory +from a2a.types import AgentCapabilities +from a2a.types import AgentCard +from a2a.types import AgentSkill +from a2a.types import Artifact +from a2a.types import Message as A2AMessage +from a2a.types import Part as A2ATaskStatus +from a2a.types import SendMessageSuccessResponse +from a2a.types import Task as A2ATask +from a2a.types import TaskArtifactUpdateEvent +from a2a.types import TaskState +from a2a.types import TaskStatus +from a2a.types import TaskStatusUpdateEvent +from a2a.types import TextPart +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX +from google.adk.agents.remote_a2a_agent import AgentCardResolutionError +from google.adk.agents.remote_a2a_agent import RemoteA2aAgent from google.adk.events.event import Event from google.adk.sessions.session import Session from google.genai import types as genai_types import httpx import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) - -# Import 
dependencies with version checking -try: - from a2a.client.client import ClientConfig - from a2a.client.client import Consumer - from a2a.client.client_factory import ClientFactory - from a2a.types import AgentCapabilities - from a2a.types import AgentCard - from a2a.types import AgentSkill - from a2a.types import Artifact - from a2a.types import Message as A2AMessage - from a2a.types import Part as A2ATaskStatus - from a2a.types import SendMessageSuccessResponse - from a2a.types import Task as A2ATask - from a2a.types import TaskArtifactUpdateEvent - from a2a.types import TaskState - from a2a.types import TaskStatus - from a2a.types import TaskStatusUpdateEvent - from a2a.types import TextPart - from google.adk.agents.invocation_context import InvocationContext - from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX - from google.adk.agents.remote_a2a_agent import AgentCardResolutionError - from google.adk.agents.remote_a2a_agent import RemoteA2aAgent -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during module compilation. - # These are needed because the module has type annotations and module-level - # helper functions that reference imported types. - class DummyTypes: - pass - - AgentCapabilities = DummyTypes() - AgentCard = DummyTypes() - AgentSkill = DummyTypes() - A2AMessage = DummyTypes() - SendMessageSuccessResponse = DummyTypes() - A2ATask = DummyTypes() - TaskStatusUpdateEvent = DummyTypes() - Artifact = DummyTypes() - TaskArtifactUpdateEvent = DummyTypes() - InvocationContext = DummyTypes() - RemoteA2aAgent = DummyTypes() - AgentCardResolutionError = Exception - A2A_METADATA_PREFIX = "" - else: - raise e - # Helper function to create a proper AgentCard for testing def create_test_agent_card( diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 1fe04732f5..75d5679084 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -509,8 +509,6 @@ async def create_test_eval_set( @pytest.fixture def temp_agents_dir_with_a2a(): """Create a temporary agents directory with A2A agent configurations for testing.""" - if sys.version_info < (3, 10): - pytest.skip("A2A requires Python 3.10+") with tempfile.TemporaryDirectory() as temp_dir: # Create test agent directory agent_dir = Path(temp_dir) / "test_a2a_agent" @@ -554,9 +552,6 @@ def test_app_with_a2a( temp_agents_dir_with_a2a, ): """Create a TestClient for the FastAPI app with A2A enabled.""" - if sys.version_info < (3, 10): - pytest.skip("A2A requires Python 3.10+") - # Mock A2A related classes with ( patch("signal.signal", return_value=None), @@ -1150,9 +1145,6 @@ def list_agents(self): assert "dotSrc" in response.json() -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) def test_a2a_agent_discovery(test_app_with_a2a): """Test that A2A agents are properly discovered and configured.""" # This test mainly verifies that the A2A setup doesn't break the app @@ -1161,9 +1153,6 @@ def test_a2a_agent_discovery(test_app_with_a2a): logger.info("A2A agent discovery test passed") -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="A2A requires Python 3.10+" -) def test_a2a_disabled_by_default(test_app): """Test that A2A functionality is disabled by default.""" # The regular test_app fixture has a2a=False diff --git a/tests/unittests/evaluation/test_local_eval_service.py b/tests/unittests/evaluation/test_local_eval_service.py index cf2ca342f3..66080828d8 
100644 --- a/tests/unittests/evaluation/test_local_eval_service.py +++ b/tests/unittests/evaluation/test_local_eval_service.py @@ -536,9 +536,6 @@ def test_generate_final_eval_status_doesn_t_throw_on(eval_service): @pytest.mark.asyncio -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) async def test_mcp_stdio_agent_no_runtime_error(mocker): """Test that LocalEvalService can handle MCP stdio agents without RuntimeError. diff --git a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py index 1e15f33899..2cf52e99cb 100644 --- a/tests/unittests/plugins/test_reflect_retry_tool_plugin.py +++ b/tests/unittests/plugins/test_reflect_retry_tool_plugin.py @@ -57,10 +57,8 @@ async def extract_error_from_result( return None -# Inheriting from IsolatedAsyncioTestCase ensures these tests works in Python -# 3.9. See https://github.com/pytest-dev/pytest-asyncio/issues/1039 -# Without this, the tests will fail with a "RuntimeError: There is no current -# event loop in thread 'MainThread'." +# Inheriting from IsolatedAsyncioTestCase ensures consistent behavior. +# See https://github.com/pytest-dev/pytest-asyncio/issues/1039 class TestReflectAndRetryToolPlugin(IsolatedAsyncioTestCase): """Comprehensive tests for ReflectAndRetryToolPlugin focusing on behavior.""" diff --git a/tests/unittests/telemetry/test_functional.py b/tests/unittests/telemetry/test_functional.py index 409571ad1f..43fe672333 100644 --- a/tests/unittests/telemetry/test_functional.py +++ b/tests/unittests/telemetry/test_functional.py @@ -103,7 +103,7 @@ def wrapped_firstiter(coro): isinstance(referrer, Aclosing) or isinstance(indirect_referrer, Aclosing) for referrer in gc.get_referrers(coro) - # Some coroutines have a layer of indirection in python 3.9 and 3.10 + # Some coroutines have a layer of indirection in Python 3.10 for indirect_referrer in gc.get_referrers(referrer) ), f'Coro `{coro.__name__}` is not wrapped with Aclosing' firstiter(coro) diff --git a/tests/unittests/tools/computer_use/test_computer_use_tool.py b/tests/unittests/tools/computer_use/test_computer_use_tool.py index 4dbdfbb5c0..f3843b87a6 100644 --- a/tests/unittests/tools/computer_use/test_computer_use_tool.py +++ b/tests/unittests/tools/computer_use/test_computer_use_tool.py @@ -47,7 +47,7 @@ async def tool_context(self): @pytest.fixture def mock_computer_function(self): """Fixture providing a mock computer function.""" - # Create a real async function instead of AsyncMock for Python 3.9 compatibility + # Create a real async function instead of AsyncMock for better test control calls = [] async def mock_func(*args, **kwargs): diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index b2d6b1cb88..74eabe9d4d 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -22,46 +22,14 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from mcp import StdioServerParameters import pytest -# Skip all 
tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_session_manager import retry_on_errors - from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - MCPSessionManager = DummyClass - retry_on_errors = lambda x: x - SseConnectionParams = DummyClass - StdioConnectionParams = DummyClass - StreamableHTTPConnectionParams = DummyClass - else: - raise e - -# Import real MCP classes -try: - from mcp import StdioServerParameters -except ImportError: - # Create a mock if MCP is not available - class StdioServerParameters: - - def __init__(self, command="test_command", args=None): - self.command = command - self.args = args or [] - class MockClientSession: """Mock ClientSession for testing.""" diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index 17b1d8e54e..1284e73bce 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -23,39 +22,15 @@ from google.adk.auth.auth_credential import HttpCredentials from google.adk.auth.auth_credential import OAuth2Auth from google.adk.auth.auth_credential import ServiceAccount +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.tool_context import ToolContext +from google.genai.types import FunctionDeclaration +from google.genai.types import Type +from mcp.types import CallToolResult +from mcp.types import TextContent import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_tool import MCPTool - from google.adk.tools.tool_context import ToolContext - from google.genai.types import FunctionDeclaration - from google.genai.types import Type - from mcp.types import CallToolResult - from mcp.types import TextContent -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - MCPSessionManager = DummyClass - MCPTool = DummyClass - ToolContext = DummyClass - FunctionDeclaration = DummyClass - Type = DummyClass - CallToolResult = DummyClass - TextContent = DummyClass - else: - raise e - # Mock MCP Tool from mcp.types class MockMCPTool: diff --git a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py 
b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py index 82a5c9a3e7..5809efe56f 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_toolset.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_toolset.py @@ -20,47 +20,17 @@ from unittest.mock import Mock from unittest.mock import patch +from google.adk.agents.readonly_context import ReadonlyContext from google.adk.auth.auth_credential import AuthCredential +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager +from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams +from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams +from google.adk.tools.mcp_tool.mcp_tool import MCPTool +from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset +from mcp import StdioServerParameters import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.agents.readonly_context import ReadonlyContext - from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager - from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams - from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams - from google.adk.tools.mcp_tool.mcp_tool import MCPTool - from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset - from mcp import StdioServerParameters -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - class StdioServerParameters: - - def __init__(self, command="test_command", args=None): - self.command = command - self.args = args or [] - - MCPSessionManager = DummyClass - SseConnectionParams = DummyClass - StdioConnectionParams = DummyClass - StreamableHTTPConnectionParams = DummyClass - MCPTool = DummyClass - MCPToolset = DummyClass - ReadonlyContext = DummyClass - else: - raise e - class MockMCPTool: """Mock MCP Tool for testing.""" diff --git a/tests/unittests/tools/retrieval/test_files_retrieval.py b/tests/unittests/tools/retrieval/test_files_retrieval.py index ea4b99cd98..dfb7215dce 100644 --- a/tests/unittests/tools/retrieval/test_files_retrieval.py +++ b/tests/unittests/tools/retrieval/test_files_retrieval.py @@ -14,7 +14,6 @@ """Tests for FilesRetrieval tool.""" -import sys import unittest.mock as mock from google.adk.tools.retrieval.files_retrieval import _get_default_embedding_model @@ -111,9 +110,6 @@ def mock_import(name, *args, **kwargs): def test_get_default_embedding_model_success(self): """Test _get_default_embedding_model returns Google embedding when available.""" - # Skip this test in Python 3.9 where llama_index.embeddings.google_genai may not be available - if sys.version_info < (3, 10): - pytest.skip("llama_index.embeddings.google_genai requires Python 3.10+") # Mock the module creation to avoid import issues mock_module = mock.MagicMock() diff --git a/tests/unittests/tools/test_mcp_toolset.py b/tests/unittests/tools/test_mcp_toolset.py index a3a6598e35..7bfd912669 100644 --- a/tests/unittests/tools/test_mcp_toolset.py +++ b/tests/unittests/tools/test_mcp_toolset.py @@ -14,31 +14,12 @@ """Unit tests 
for McpToolset.""" -import sys from unittest.mock import AsyncMock from unittest.mock import MagicMock +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset import pytest -# Skip all tests in this module if Python version is less than 3.10 -pytestmark = pytest.mark.skipif( - sys.version_info < (3, 10), reason="MCP tool requires Python 3.10+" -) - -# Import dependencies with version checking -try: - from google.adk.tools.mcp_tool.mcp_toolset import McpToolset -except ImportError as e: - if sys.version_info < (3, 10): - # Create dummy classes to prevent NameError during test collection - # Tests will be skipped anyway due to pytestmark - class DummyClass: - pass - - McpToolset = DummyClass - else: - raise e - @pytest.mark.asyncio async def test_mcp_toolset_with_prefix(): From 9d918d45df4275b5b464e46817d2daaa03859fe3 Mon Sep 17 00:00:00 2001 From: Shangjie Chen Date: Wed, 3 Dec 2025 10:39:17 -0800 Subject: [PATCH 61/63] feat!: Rollback the DB migration as it is breaking Co-authored-by: Shangjie Chen PiperOrigin-RevId: 839818479 --- src/google/adk/cli/cli_tools_click.py | 36 -- .../adk/sessions/database_session_service.py | 222 +++++--- .../migrate_from_sqlalchemy_sqlite.py | 0 .../adk/sessions/migration/_schema_check.py | 114 ---- .../migrate_from_sqlalchemy_pickle.py | 492 ------------------ .../sessions/migration/migration_runner.py | 128 ----- .../adk/sessions/sqlite_session_service.py | 2 +- .../sessions/migration/test_migrations.py | 106 ---- .../sessions/test_dynamic_pickle_type.py | 181 +++++++ 9 files changed, 342 insertions(+), 939 deletions(-) rename src/google/adk/sessions/{migration => }/migrate_from_sqlalchemy_sqlite.py (100%) delete mode 100644 src/google/adk/sessions/migration/_schema_check.py delete mode 100644 src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py delete mode 100644 src/google/adk/sessions/migration/migration_runner.py delete mode 100644 tests/unittests/sessions/migration/test_migrations.py create mode 100644 tests/unittests/sessions/test_dynamic_pickle_type.py diff --git a/src/google/adk/cli/cli_tools_click.py b/src/google/adk/cli/cli_tools_click.py index c4446278b4..5d228f72f3 100644 --- a/src/google/adk/cli/cli_tools_click.py +++ b/src/google/adk/cli/cli_tools_click.py @@ -36,7 +36,6 @@ from . import cli_deploy from .. import version from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE -from ..sessions.migration import migration_runner from .cli import run_cli from .fast_api import get_fast_api_app from .utils import envs @@ -1500,41 +1499,6 @@ def cli_deploy_cloud_run( click.secho(f"Deploy failed: {e}", fg="red", err=True) -@main.group() -def migrate(): - """Migrate ADK database schemas.""" - pass - - -@migrate.command("session", cls=HelpfulCommand) -@click.option( - "--source_db_url", - required=True, - help="SQLAlchemy URL of source database.", -) -@click.option( - "--dest_db_url", - required=True, - help="SQLAlchemy URL of destination database.", -) -@click.option( - "--log_level", - type=LOG_LEVELS, - default="INFO", - help="Optional. 
Set the logging level", -) -def cli_migrate_session( - *, source_db_url: str, dest_db_url: str, log_level: str -): - """Migrates a session database to the latest schema version.""" - logs.setup_adk_logger(getattr(logging, log_level.upper())) - try: - migration_runner.upgrade(source_db_url, dest_db_url) - click.secho("Migration check and upgrade process finished.", fg="green") - except Exception as e: - click.secho(f"Migration failed: {e}", fg="red", err=True) - - @deploy.command("agent_engine") @click.option( "--api_key", diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 1576151f23..a352918211 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -19,16 +19,18 @@ from datetime import timezone import json import logging +import pickle from typing import Any from typing import Optional import uuid +from google.genai import types +from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect from sqlalchemy import event from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func -from sqlalchemy import inspect from sqlalchemy import select from sqlalchemy import Text from sqlalchemy.dialects import mysql @@ -39,11 +41,14 @@ from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.inspection import inspect from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship +from sqlalchemy.schema import MetaData from sqlalchemy.types import DateTime +from sqlalchemy.types import PickleType from sqlalchemy.types import String from sqlalchemy.types import TypeDecorator from typing_extensions import override @@ -52,10 +57,10 @@ from . 
import _session_util from ..errors.already_exists_error import AlreadyExistsError from ..events.event import Event +from ..events.event_actions import EventActions from .base_session_service import BaseSessionService from .base_session_service import GetSessionConfig from .base_session_service import ListSessionsResponse -from .migration import _schema_check from .session import Session from .state import State @@ -106,20 +111,39 @@ def load_dialect_impl(self, dialect): return self.impl -class Base(DeclarativeBase): - """Base class for database tables.""" +class DynamicPickleType(TypeDecorator): + """Represents a type that can be pickled.""" - pass + impl = PickleType + def load_dialect_impl(self, dialect): + if dialect.name == "mysql": + return dialect.type_descriptor(mysql.LONGBLOB) + if dialect.name == "spanner+spanner": + from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType -class StorageMetadata(Base): - """Represents internal metadata stored in the database.""" + return dialect.type_descriptor(SpannerPickleType) + return self.impl + + def process_bind_param(self, value, dialect): + """Ensures the pickled value is a bytes object before passing it to the database dialect.""" + if value is not None: + if dialect.name in ("spanner+spanner", "mysql"): + return pickle.dumps(value) + return value + + def process_result_value(self, value, dialect): + """Ensures the raw bytes from the database are unpickled back into a Python object.""" + if value is not None: + if dialect.name in ("spanner+spanner", "mysql"): + return pickle.loads(value) + return value - __tablename__ = "adk_internal_metadata" - key: Mapped[str] = mapped_column( - String(DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - value: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + +class Base(DeclarativeBase): + """Base class for database tables.""" + + pass class StorageSession(Base): @@ -213,10 +237,46 @@ class StorageEvent(Base): ) invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH)) + actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) + long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( + Text, nullable=True + ) + branch: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) timestamp: Mapped[PreciseTimestamp] = mapped_column( PreciseTimestamp, default=func.now() ) - event_data: Mapped[dict[str, Any]] = mapped_column(DynamicJSON) + + # === Fields from llm_response.py === + content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) + grounding_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + custom_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + usage_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + citation_metadata: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + + partial: Mapped[bool] = mapped_column(Boolean, nullable=True) + turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) + error_code: Mapped[str] = mapped_column( + String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True + ) + error_message: Mapped[str] = mapped_column(String(1024), nullable=True) + interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) + input_transcription: Mapped[dict[str, Any]] = mapped_column( + DynamicJSON, nullable=True + ) + output_transcription: Mapped[dict[str, 
Any]] = mapped_column( + DynamicJSON, nullable=True + ) storage_session: Mapped[StorageSession] = relationship( "StorageSession", @@ -231,27 +291,102 @@ class StorageEvent(Base): ), ) + @property + def long_running_tool_ids(self) -> set[str]: + return ( + set(json.loads(self.long_running_tool_ids_json)) + if self.long_running_tool_ids_json + else set() + ) + + @long_running_tool_ids.setter + def long_running_tool_ids(self, value: set[str]): + if value is None: + self.long_running_tool_ids_json = None + else: + self.long_running_tool_ids_json = json.dumps(list(value)) + @classmethod def from_event(cls, session: Session, event: Event) -> StorageEvent: - """Creates a StorageEvent from an Event.""" - return StorageEvent( + storage_event = StorageEvent( id=event.id, invocation_id=event.invocation_id, + author=event.author, + branch=event.branch, + actions=event.actions, session_id=session.id, app_name=session.app_name, user_id=session.user_id, timestamp=datetime.fromtimestamp(event.timestamp), - event_data=event.model_dump(exclude_none=True, mode="json"), + long_running_tool_ids=event.long_running_tool_ids, + partial=event.partial, + turn_complete=event.turn_complete, + error_code=event.error_code, + error_message=event.error_message, + interrupted=event.interrupted, ) + if event.content: + storage_event.content = event.content.model_dump( + exclude_none=True, mode="json" + ) + if event.grounding_metadata: + storage_event.grounding_metadata = event.grounding_metadata.model_dump( + exclude_none=True, mode="json" + ) + if event.custom_metadata: + storage_event.custom_metadata = event.custom_metadata + if event.usage_metadata: + storage_event.usage_metadata = event.usage_metadata.model_dump( + exclude_none=True, mode="json" + ) + if event.citation_metadata: + storage_event.citation_metadata = event.citation_metadata.model_dump( + exclude_none=True, mode="json" + ) + if event.input_transcription: + storage_event.input_transcription = event.input_transcription.model_dump( + exclude_none=True, mode="json" + ) + if event.output_transcription: + storage_event.output_transcription = ( + event.output_transcription.model_dump(exclude_none=True, mode="json") + ) + return storage_event def to_event(self) -> Event: - """Converts the StorageEvent to an Event.""" - return Event.model_validate({ - **self.event_data, - "id": self.id, - "invocation_id": self.invocation_id, - "timestamp": self.timestamp.timestamp(), - }) + return Event( + id=self.id, + invocation_id=self.invocation_id, + author=self.author, + branch=self.branch, + # This is needed as previous ADK version pickled actions might not have + # value defined in the current version of the EventActions model. 
+ actions=EventActions().model_copy(update=self.actions.model_dump()), + timestamp=self.timestamp.timestamp(), + long_running_tool_ids=self.long_running_tool_ids, + partial=self.partial, + turn_complete=self.turn_complete, + error_code=self.error_code, + error_message=self.error_message, + interrupted=self.interrupted, + custom_metadata=self.custom_metadata, + content=_session_util.decode_model(self.content, types.Content), + grounding_metadata=_session_util.decode_model( + self.grounding_metadata, types.GroundingMetadata + ), + usage_metadata=_session_util.decode_model( + self.usage_metadata, types.GenerateContentResponseUsageMetadata + ), + citation_metadata=_session_util.decode_model( + self.citation_metadata, types.CitationMetadata + ), + input_transcription=_session_util.decode_model( + self.input_transcription, types.Transcription + ), + output_transcription=_session_util.decode_model( + self.output_transcription, types.Transcription + ), + ) class StorageAppState(Base): @@ -328,6 +463,7 @@ def __init__(self, db_url: str, **kwargs: Any): logger.info("Local timezone: %s", local_timezone) self.db_engine: AsyncEngine = db_engine + self.metadata: MetaData = MetaData() # DB session factory method self.database_session_factory: async_sessionmaker[ @@ -347,46 +483,10 @@ async def _ensure_tables_created(self): async with self._table_creation_lock: # Double-check after acquiring the lock if not self._tables_created: - # Check schema version BEFORE creating tables. - # This prevents creating metadata table on a v0.1 DB. - async with self.database_session_factory() as sql_session: - version, is_v01 = await sql_session.run_sync( - _schema_check.get_version_and_v01_status_sync - ) - - if is_v01: - raise RuntimeError( - "Database schema appears to be v0.1, but" - f" {_schema_check.CURRENT_SCHEMA_VERSION} is required. Please" - " migrate the database using 'adk migrate session'." - ) - elif version and version < _schema_check.CURRENT_SCHEMA_VERSION: - raise RuntimeError( - f"Database schema version is {version}, but current version is" - f" {_schema_check.CURRENT_SCHEMA_VERSION}. Please migrate" - " the database to the latest version using 'adk migrate" - " session'." - ) - async with self.db_engine.begin() as conn: # Uncomment to recreate DB every time # await conn.run_sync(Base.metadata.drop_all) await conn.run_sync(Base.metadata.create_all) - - # If we are here, DB is either new or >= current version. - # If new or without metadata row, stamp it as current version. 
- async with self.database_session_factory() as sql_session: - metadata = await sql_session.get( - StorageMetadata, _schema_check.SCHEMA_VERSION_KEY - ) - if not metadata: - sql_session.add( - StorageMetadata( - key=_schema_check.SCHEMA_VERSION_KEY, - value=_schema_check.CURRENT_SCHEMA_VERSION, - ) - ) - await sql_session.commit() self._tables_created = True @override @@ -623,9 +723,7 @@ async def append_event(self, session: Session, event: Event) -> Event: storage_session.state = storage_session.state | session_state_delta if storage_session._dialect_name == "sqlite": - update_time = datetime.fromtimestamp( - event.timestamp, timezone.utc - ).replace(tzinfo=None) + update_time = datetime.utcfromtimestamp(event.timestamp) else: update_time = datetime.fromtimestamp(event.timestamp) storage_session.update_time = update_time diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py b/src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py similarity index 100% rename from src/google/adk/sessions/migration/migrate_from_sqlalchemy_sqlite.py rename to src/google/adk/sessions/migrate_from_sqlalchemy_sqlite.py diff --git a/src/google/adk/sessions/migration/_schema_check.py b/src/google/adk/sessions/migration/_schema_check.py deleted file mode 100644 index f6fdc59956..0000000000 --- a/src/google/adk/sessions/migration/_schema_check.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Database schema version check utility.""" - -from __future__ import annotations - -import logging - -import sqlalchemy -from sqlalchemy import create_engine as create_sync_engine -from sqlalchemy import inspect -from sqlalchemy import text - -logger = logging.getLogger("google_adk." + __name__) - -SCHEMA_VERSION_KEY = "schema_version" -SCHEMA_VERSION_0_1_PICKLE = "0.1" -SCHEMA_VERSION_1_0_JSON = "1.0" -CURRENT_SCHEMA_VERSION = "1.0" - - -def _to_sync_url(db_url: str) -> str: - """Removes +driver from SQLAlchemy URL.""" - if "://" in db_url: - scheme, _, rest = db_url.partition("://") - if "+" in scheme: - dialect = scheme.split("+", 1)[0] - return f"{dialect}://{rest}" - return db_url - - -def get_version_and_v01_status_sync( - sess: sqlalchemy.orm.Session, -) -> tuple[str | None, bool]: - """Returns (version, is_v01) inspecting the database.""" - inspector = sqlalchemy.inspect(sess.get_bind()) - if inspector.has_table("adk_internal_metadata"): - try: - result = sess.execute( - text("SELECT value FROM adk_internal_metadata WHERE key = :key"), - {"key": SCHEMA_VERSION_KEY}, - ).fetchone() - # If table exists, with or without key, it's 1.0 or newer. - return (result[0] if result else SCHEMA_VERSION_1_0_JSON), False - except Exception as e: - logger.warning( - "Could not read from adk_internal_metadata: %s. 
Assuming v1.0.", - e, - ) - return SCHEMA_VERSION_1_0_JSON, False - - if inspector.has_table("events"): - try: - cols = {c["name"] for c in inspector.get_columns("events")} - if "actions" in cols and "event_data" not in cols: - return None, True # 0.1 schema - except Exception as e: - logger.warning("Could not inspect 'events' table columns: %s", e) - return None, False # New DB - - -def get_db_schema_version(db_url: str) -> str | None: - """Reads schema version from DB. - - Checks metadata table first, falls back to table structure for 0.1 vs 1.0. - """ - engine = None - try: - engine = create_sync_engine(_to_sync_url(db_url)) - inspector = inspect(engine) - - if inspector.has_table("adk_internal_metadata"): - with engine.connect() as connection: - result = connection.execute( - text("SELECT value FROM adk_internal_metadata WHERE key = :key"), - parameters={"key": SCHEMA_VERSION_KEY}, - ).fetchone() - # If table exists, with or without key, it's 1.0 or newer. - return result[0] if result else SCHEMA_VERSION_1_0_JSON - - # Metadata table doesn't exist, check for 0.1 schema. - # 0.1 schema has an 'events' table with an 'actions' column. - if inspector.has_table("events"): - try: - cols = {c["name"] for c in inspector.get_columns("events")} - if "actions" in cols and "event_data" not in cols: - return SCHEMA_VERSION_0_1_PICKLE - except Exception as e: - logger.warning("Could not inspect 'events' table columns: %s", e) - - # If no metadata table and not identifiable as 0.1, - # assume it is a new/empty DB requiring schema 1.0. - return SCHEMA_VERSION_1_0_JSON - except Exception as e: - logger.info( - "Could not determine schema version by inspecting database: %s." - " Assuming v1.0.", - e, - ) - return SCHEMA_VERSION_1_0_JSON - finally: - if engine: - engine.dispose() diff --git a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py b/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py deleted file mode 100644 index f33ef3f5cf..0000000000 --- a/src/google/adk/sessions/migration/migrate_from_sqlalchemy_pickle.py +++ /dev/null @@ -1,492 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Migration script from SQLAlchemy DB with Pickle Events to JSON schema.""" - -from __future__ import annotations - -import argparse -from datetime import datetime -from datetime import timezone -import json -import logging -import pickle -import sys -from typing import Any -from typing import Optional - -from google.adk.events.event import Event -from google.adk.events.event_actions import EventActions -from google.adk.sessions import _session_util -from google.adk.sessions import database_session_service as dss -from google.adk.sessions.migration import _schema_check -from google.genai import types -import sqlalchemy -from sqlalchemy import Boolean -from sqlalchemy import create_engine -from sqlalchemy import ForeignKeyConstraint -from sqlalchemy import func -from sqlalchemy import text -from sqlalchemy import Text -from sqlalchemy.dialects import mysql -from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.orm import DeclarativeBase -from sqlalchemy.orm import Mapped -from sqlalchemy.orm import mapped_column -from sqlalchemy.orm import sessionmaker -from sqlalchemy.types import PickleType -from sqlalchemy.types import String -from sqlalchemy.types import TypeDecorator - -logger = logging.getLogger("google_adk." + __name__) - - -# --- Old Schema Definitions --- -class DynamicPickleType(TypeDecorator): - """Represents a type that can be pickled.""" - - impl = PickleType - - def load_dialect_impl(self, dialect): - if dialect.name == "mysql": - return dialect.type_descriptor(mysql.LONGBLOB) - if dialect.name == "spanner+spanner": - from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType - - return dialect.type_descriptor(SpannerPickleType) - return self.impl - - def process_bind_param(self, value, dialect): - """Ensures the pickled value is a bytes object before passing it to the database dialect.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.dumps(value) - return value - - def process_result_value(self, value, dialect): - """Ensures the raw bytes from the database are unpickled back into a Python object.""" - if value is not None: - if dialect.name in ("spanner+spanner", "mysql"): - return pickle.loads(value) - return value - - -class OldBase(DeclarativeBase): - pass - - -class OldStorageSession(OldBase): - __tablename__ = "sessions" - app_name: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - user_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - state: Mapped[MutableDict[str, Any]] = mapped_column( - MutableDict.as_mutable(dss.DynamicJSON), default={} - ) - create_time: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now() - ) - update_time: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now(), onupdate=func.now() - ) - - -class OldStorageEvent(OldBase): - """Old storage event with pickle.""" - - __tablename__ = "events" - id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - app_name: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - user_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - session_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - invocation_id: Mapped[str] = mapped_column( - 
String(dss.DEFAULT_MAX_VARCHAR_LENGTH) - ) - author: Mapped[str] = mapped_column(String(dss.DEFAULT_MAX_VARCHAR_LENGTH)) - actions: Mapped[Any] = mapped_column(DynamicPickleType) - long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( - Text, nullable=True - ) - branch: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True - ) - timestamp: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now() - ) - content: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - grounding_metadata: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - custom_metadata: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - usage_metadata: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - citation_metadata: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - partial: Mapped[bool] = mapped_column(Boolean, nullable=True) - turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True) - error_code: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_VARCHAR_LENGTH), nullable=True - ) - error_message: Mapped[str] = mapped_column(String(1024), nullable=True) - interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True) - input_transcription: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - output_transcription: Mapped[dict[str, Any]] = mapped_column( - dss.DynamicJSON, nullable=True - ) - __table_args__ = ( - ForeignKeyConstraint( - ["app_name", "user_id", "session_id"], - ["sessions.app_name", "sessions.user_id", "sessions.id"], - ondelete="CASCADE", - ), - ) - - @property - def long_running_tool_ids(self) -> set[str]: - return ( - set(json.loads(self.long_running_tool_ids_json)) - if self.long_running_tool_ids_json - else set() - ) - - -def _to_datetime_obj(val: Any) -> datetime | Any: - """Converts string to datetime if needed.""" - if isinstance(val, str): - try: - return datetime.strptime(val, "%Y-%m-%d %H:%M:%S.%f") - except ValueError: - try: - return datetime.strptime(val, "%Y-%m-%d %H:%M:%S") - except ValueError: - pass # return as is if not matching format - return val - - -def _row_to_event(row: dict) -> Event: - """Converts event row (dict) to event object, handling missing columns and deserializing.""" - - actions_val = row.get("actions") - actions = None - if actions_val is not None: - try: - if isinstance(actions_val, bytes): - actions = pickle.loads(actions_val) - else: # for spanner - it might return object directly - actions = actions_val - except Exception as e: - logger.warning( - f"Failed to unpickle actions for event {row.get('id')}: {e}" - ) - actions = None - - if actions and hasattr(actions, "model_dump"): - actions = EventActions().model_copy(update=actions.model_dump()) - elif isinstance(actions, dict): - actions = EventActions(**actions) - else: - actions = EventActions() - - def _safe_json_load(val): - data = None - if isinstance(val, str): - try: - data = json.loads(val) - except json.JSONDecodeError: - logger.warning(f"Failed to decode JSON for event {row.get('id')}") - return None - elif isinstance(val, dict): - data = val # for postgres JSONB - return data - - content_dict = _safe_json_load(row.get("content")) - grounding_metadata_dict = _safe_json_load(row.get("grounding_metadata")) - custom_metadata_dict = _safe_json_load(row.get("custom_metadata")) - usage_metadata_dict = _safe_json_load(row.get("usage_metadata")) - citation_metadata_dict = 
_safe_json_load(row.get("citation_metadata")) - input_transcription_dict = _safe_json_load(row.get("input_transcription")) - output_transcription_dict = _safe_json_load(row.get("output_transcription")) - - long_running_tool_ids_json = row.get("long_running_tool_ids_json") - long_running_tool_ids = set() - if long_running_tool_ids_json: - try: - long_running_tool_ids = set(json.loads(long_running_tool_ids_json)) - except json.JSONDecodeError: - logger.warning( - "Failed to decode long_running_tool_ids_json for event" - f" {row.get('id')}" - ) - long_running_tool_ids = set() - - event_id = row.get("id") - if not event_id: - raise ValueError("Event must have an id.") - timestamp = _to_datetime_obj(row.get("timestamp")) - if not timestamp: - raise ValueError(f"Event {event_id} must have a timestamp.") - - return Event( - id=event_id, - invocation_id=row.get("invocation_id", ""), - author=row.get("author", "agent"), - branch=row.get("branch"), - actions=actions, - timestamp=timestamp.replace(tzinfo=timezone.utc).timestamp(), - long_running_tool_ids=long_running_tool_ids, - partial=row.get("partial"), - turn_complete=row.get("turn_complete"), - error_code=row.get("error_code"), - error_message=row.get("error_message"), - interrupted=row.get("interrupted"), - custom_metadata=custom_metadata_dict, - content=_session_util.decode_model(content_dict, types.Content), - grounding_metadata=_session_util.decode_model( - grounding_metadata_dict, types.GroundingMetadata - ), - usage_metadata=_session_util.decode_model( - usage_metadata_dict, types.GenerateContentResponseUsageMetadata - ), - citation_metadata=_session_util.decode_model( - citation_metadata_dict, types.CitationMetadata - ), - input_transcription=_session_util.decode_model( - input_transcription_dict, types.Transcription - ), - output_transcription=_session_util.decode_model( - output_transcription_dict, types.Transcription - ), - ) - - -def _get_state_dict(state_val: Any) -> dict: - """Safely load dict from JSON string or return dict if already dict.""" - if isinstance(state_val, dict): - return state_val - if isinstance(state_val, str): - try: - return json.loads(state_val) - except json.JSONDecodeError: - logger.warning( - "Failed to parse state JSON string, defaulting to empty dict." 
- ) - return {} - return {} - - -class OldStorageAppState(OldBase): - __tablename__ = "app_states" - app_name: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - state: Mapped[MutableDict[str, Any]] = mapped_column( - MutableDict.as_mutable(dss.DynamicJSON), default={} - ) - update_time: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now(), onupdate=func.now() - ) - - -class OldStorageUserState(OldBase): - __tablename__ = "user_states" - app_name: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - user_id: Mapped[str] = mapped_column( - String(dss.DEFAULT_MAX_KEY_LENGTH), primary_key=True - ) - state: Mapped[MutableDict[str, Any]] = mapped_column( - MutableDict.as_mutable(dss.DynamicJSON), default={} - ) - update_time: Mapped[datetime] = mapped_column( - dss.PreciseTimestamp, default=func.now(), onupdate=func.now() - ) - - -# --- Migration Logic --- -def migrate(source_db_url: str, dest_db_url: str): - """Migrates data from old pickle schema to new JSON schema.""" - logger.info(f"Connecting to source database: {source_db_url}") - try: - source_engine = create_engine(source_db_url) - SourceSession = sessionmaker(bind=source_engine) - except Exception as e: - logger.error(f"Failed to connect to source database: {e}") - raise RuntimeError(f"Failed to connect to source database: {e}") from e - - logger.info(f"Connecting to destination database: {dest_db_url}") - try: - dest_engine = create_engine(dest_db_url) - dss.Base.metadata.create_all(dest_engine) - DestSession = sessionmaker(bind=dest_engine) - except Exception as e: - logger.error(f"Failed to connect to destination database: {e}") - raise RuntimeError(f"Failed to connect to destination database: {e}") from e - - with SourceSession() as source_session, DestSession() as dest_session: - dest_session.merge( - dss.StorageMetadata( - key=_schema_check.SCHEMA_VERSION_KEY, - value=_schema_check.SCHEMA_VERSION_1_0_JSON, - ) - ) - dest_session.commit() - try: - inspector = sqlalchemy.inspect(source_engine) - - logger.info("Migrating app_states...") - if inspector.has_table("app_states"): - rows = ( - source_session.execute(text("SELECT * FROM app_states")) - .mappings() - .all() - ) - for row in rows: - dest_session.merge( - dss.StorageAppState( - app_name=row["app_name"], - state=_get_state_dict(row.get("state")), - update_time=_to_datetime_obj(row["update_time"]), - ) - ) - dest_session.commit() - logger.info(f"Migrated {len(rows)} app_states.") - else: - logger.info("No 'app_states' table found in source db.") - - logger.info("Migrating user_states...") - if inspector.has_table("user_states"): - rows = ( - source_session.execute(text("SELECT * FROM user_states")) - .mappings() - .all() - ) - for row in rows: - dest_session.merge( - dss.StorageUserState( - app_name=row["app_name"], - user_id=row["user_id"], - state=_get_state_dict(row.get("state")), - update_time=_to_datetime_obj(row["update_time"]), - ) - ) - dest_session.commit() - logger.info(f"Migrated {len(rows)} user_states.") - else: - logger.info("No 'user_states' table found in source db.") - - logger.info("Migrating sessions...") - if inspector.has_table("sessions"): - rows = ( - source_session.execute(text("SELECT * FROM sessions")) - .mappings() - .all() - ) - for row in rows: - dest_session.merge( - dss.StorageSession( - app_name=row["app_name"], - user_id=row["user_id"], - id=row["id"], - state=_get_state_dict(row.get("state")), - create_time=_to_datetime_obj(row["create_time"]), 
- update_time=_to_datetime_obj(row["update_time"]), - ) - ) - dest_session.commit() - logger.info(f"Migrated {len(rows)} sessions.") - else: - logger.info("No 'sessions' table found in source db.") - - logger.info("Migrating events...") - events = [] - if inspector.has_table("events"): - rows = ( - source_session.execute(text("SELECT * FROM events")) - .mappings() - .all() - ) - for row in rows: - try: - event_obj = _row_to_event(dict(row)) - new_event = dss.StorageEvent( - id=event_obj.id, - app_name=row["app_name"], - user_id=row["user_id"], - session_id=row["session_id"], - invocation_id=event_obj.invocation_id, - timestamp=datetime.fromtimestamp( - event_obj.timestamp, timezone.utc - ).replace(tzinfo=None), - event_data=event_obj.model_dump(mode="json", exclude_none=True), - ) - dest_session.merge(new_event) - events.append(new_event) - except Exception as e: - logger.warning( - f"Failed to migrate event row {row.get('id', 'N/A')}: {e}" - ) - dest_session.commit() - logger.info(f"Migrated {len(events)} events.") - else: - logger.info("No 'events' table found in source database.") - - logger.info("Migration completed successfully.") - except Exception as e: - logger.error(f"An error occurred during migration: {e}", exc_info=True) - dest_session.rollback() - raise RuntimeError(f"An error occurred during migration: {e}") from e - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description=( - "Migrate ADK sessions from SQLAlchemy Pickle format to JSON format." - ) - ) - parser.add_argument( - "--source_db_url", required=True, help="SQLAlchemy URL of source database" - ) - parser.add_argument( - "--dest_db_url", - required=True, - help="SQLAlchemy URL of destination database", - ) - args = parser.parse_args() - try: - migrate(args.source_db_url, args.dest_db_url) - except Exception as e: - logger.error(f"Migration failed: {e}") - sys.exit(1) diff --git a/src/google/adk/sessions/migration/migration_runner.py b/src/google/adk/sessions/migration/migration_runner.py deleted file mode 100644 index d7abbe41f9..0000000000 --- a/src/google/adk/sessions/migration/migration_runner.py +++ /dev/null @@ -1,128 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Migration runner to upgrade schemas to the latest version.""" - -from __future__ import annotations - -import logging -import os -import tempfile - -from google.adk.sessions.migration import _schema_check -from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle - -logger = logging.getLogger("google_adk." + __name__) - -# Migration map where key is start_version and value is -# (end_version, migration_function). -# Each key is a schema version, and its value is a tuple containing: -# (the schema version AFTER this migration step, the migration function to run). -# The migration function should accept (source_db_url, dest_db_url) as -# arguments. 
-MIGRATIONS = { - _schema_check.SCHEMA_VERSION_0_1_PICKLE: ( - _schema_check.SCHEMA_VERSION_1_0_JSON, - migrate_from_sqlalchemy_pickle.migrate, - ), -} -# The most recent schema version. The migration process stops once this version -# is reached. -LATEST_VERSION = _schema_check.CURRENT_SCHEMA_VERSION - - -def upgrade(source_db_url: str, dest_db_url: str): - """Migrates a database from its current version to the latest version. - - If the source database schema is older than the latest version, this - function applies migration scripts sequentially until the schema reaches the - LATEST_VERSION. - - If multiple migration steps are required, intermediate results are stored in - temporary SQLite database files. This means a multi-step migration - between other database types (e.g. PostgreSQL to PostgreSQL) will use - SQLite for intermediate steps. - - In-place migration (source_db_url == dest_db_url) is not supported, - as migrations always read from a source and write to a destination. - - Args: - source_db_url: The SQLAlchemy URL of the database to migrate from. - dest_db_url: The SQLAlchemy URL of the database to migrate to. This must be - different from source_db_url. - - Raises: - RuntimeError: If source_db_url and dest_db_url are the same, or if no - migration path is found. - """ - current_version = _schema_check.get_db_schema_version(source_db_url) - - if current_version == LATEST_VERSION: - logger.info( - f"Database {source_db_url} is already at latest version" - f" {LATEST_VERSION}. No migration needed." - ) - return - - if source_db_url == dest_db_url: - raise RuntimeError( - "In-place migration is not supported. " - "Please provide a different file for dest_db_url." - ) - - # Build the list of migration steps required to reach LATEST_VERSION. - migrations_to_run = [] - ver = current_version - while ver in MIGRATIONS and ver != LATEST_VERSION: - migrations_to_run.append(MIGRATIONS[ver]) - ver = MIGRATIONS[ver][0] - - if not migrations_to_run: - raise RuntimeError( - "Could not find migration path for schema version" - f" {current_version} to {LATEST_VERSION}." - ) - - temp_files = [] - in_url = source_db_url - try: - for i, (end_version, migrate_func) in enumerate(migrations_to_run): - is_last_step = i == len(migrations_to_run) - 1 - - if is_last_step: - out_url = dest_db_url - else: - # For intermediate steps, create a temporary SQLite DB to store the - # result. - fd, temp_path = tempfile.mkstemp(suffix=".db") - os.close(fd) - out_url = f"sqlite:///{temp_path}" - temp_files.append(temp_path) - logger.debug(f"Created temp db {out_url} for step {i+1}") - - logger.info( - f"Migrating from {in_url} to {out_url} (schema {end_version})..." - ) - migrate_func(in_url, out_url) - logger.info(f"Finished migration step to schema {end_version}.") - # The output of this step becomes the input for the next step. - in_url = out_url - finally: - # Ensure temporary files are cleaned up even if migration fails. - # Cleanup temp files - for path in temp_files: - try: - os.remove(path) - logger.debug(f"Removed temp db {path}") - except OSError as e: - logger.warning(f"Failed to remove temp db file {path}: {e}") diff --git a/src/google/adk/sessions/sqlite_session_service.py b/src/google/adk/sessions/sqlite_session_service.py index e0d44b3872..8ba6531f52 100644 --- a/src/google/adk/sessions/sqlite_session_service.py +++ b/src/google/adk/sessions/sqlite_session_service.py @@ -107,7 +107,7 @@ def __init__(self, db_path: str): f"Database {db_path} seems to use an old schema." 
" Please run the migration command to" " migrate it to the new schema. Example: `python -m" - " google.adk.sessions.migration.migrate_from_sqlalchemy_sqlite" + " google.adk.sessions.migrate_from_sqlalchemy_sqlite" f" --source_db_path {db_path} --dest_db_path" f" {db_path}.new` then backup {db_path} and rename" f" {db_path}.new to {db_path}." diff --git a/tests/unittests/sessions/migration/test_migrations.py b/tests/unittests/sessions/migration/test_migrations.py deleted file mode 100644 index 938387d29b..0000000000 --- a/tests/unittests/sessions/migration/test_migrations.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for migration scripts.""" - -from __future__ import annotations - -from datetime import datetime -from datetime import timezone - -from google.adk.events.event_actions import EventActions -from google.adk.sessions import database_session_service as dss -from google.adk.sessions.migration import _schema_check -from google.adk.sessions.migration import migrate_from_sqlalchemy_pickle as mfsp -import pytest -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - - -def test_migrate_from_sqlalchemy_pickle(tmp_path): - """Tests for migrate_from_sqlalchemy_pickle.""" - source_db_path = tmp_path / "source_pickle.db" - dest_db_path = tmp_path / "dest_json.db" - source_db_url = f"sqlite:///{source_db_path}" - dest_db_url = f"sqlite:///{dest_db_path}" - - # Setup source DB with old pickle schema - source_engine = create_engine(source_db_url) - mfsp.OldBase.metadata.create_all(source_engine) - SourceSession = sessionmaker(bind=source_engine) - source_session = SourceSession() - - # Populate source data - now = datetime.now(timezone.utc) - app_state = mfsp.OldStorageAppState( - app_name="app1", state={"akey": 1}, update_time=now - ) - user_state = mfsp.OldStorageUserState( - app_name="app1", user_id="user1", state={"ukey": 2}, update_time=now - ) - session = mfsp.OldStorageSession( - app_name="app1", - user_id="user1", - id="session1", - state={"skey": 3}, - create_time=now, - update_time=now, - ) - event = mfsp.OldStorageEvent( - id="event1", - app_name="app1", - user_id="user1", - session_id="session1", - invocation_id="invoke1", - author="user", - actions=EventActions(state_delta={"skey": 4}), - timestamp=now, - ) - source_session.add_all([app_state, user_state, session, event]) - source_session.commit() - source_session.close() - - mfsp.migrate(source_db_url, dest_db_url) - - # Verify destination DB - dest_engine = create_engine(dest_db_url) - DestSession = sessionmaker(bind=dest_engine) - dest_session = DestSession() - - metadata = dest_session.query(dss.StorageMetadata).first() - assert metadata is not None - assert metadata.key == _schema_check.SCHEMA_VERSION_KEY - assert metadata.value == _schema_check.SCHEMA_VERSION_1_0_JSON - - app_state_res = dest_session.query(dss.StorageAppState).first() - assert app_state_res is not None - assert app_state_res.app_name == "app1" - assert 
app_state_res.state == {"akey": 1} - - user_state_res = dest_session.query(dss.StorageUserState).first() - assert user_state_res is not None - assert user_state_res.user_id == "user1" - assert user_state_res.state == {"ukey": 2} - - session_res = dest_session.query(dss.StorageSession).first() - assert session_res is not None - assert session_res.id == "session1" - assert session_res.state == {"skey": 3} - - event_res = dest_session.query(dss.StorageEvent).first() - assert event_res is not None - assert event_res.id == "event1" - assert "state_delta" in event_res.event_data["actions"] - assert event_res.event_data["actions"]["state_delta"] == {"skey": 4} - - dest_session.close() diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py new file mode 100644 index 0000000000..e4eb084f88 --- /dev/null +++ b/tests/unittests/sessions/test_dynamic_pickle_type.py @@ -0,0 +1,181 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pickle +from unittest import mock + +from google.adk.sessions.database_session_service import DynamicPickleType +import pytest +from sqlalchemy import create_engine +from sqlalchemy.dialects import mysql + + +@pytest.fixture +def pickle_type(): + """Fixture for DynamicPickleType instance.""" + return DynamicPickleType() + + +def test_load_dialect_impl_mysql(pickle_type): + """Test that MySQL dialect uses LONGBLOB.""" + # Mock the MySQL dialect + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + # Mock the return value of type_descriptor + mock_longblob_type = mock.Mock() + mock_dialect.type_descriptor.return_value = mock_longblob_type + + impl = pickle_type.load_dialect_impl(mock_dialect) + + # Verify type_descriptor was called once with mysql.LONGBLOB + mock_dialect.type_descriptor.assert_called_once_with(mysql.LONGBLOB) + # Verify the return value is what we expect + assert impl == mock_longblob_type + + +def test_load_dialect_impl_spanner(pickle_type): + """Test that Spanner dialect uses SpannerPickleType.""" + # Mock the spanner dialect + mock_dialect = mock.Mock() + mock_dialect.name = "spanner+spanner" + + with mock.patch( + "google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.SpannerPickleType" + ) as mock_spanner_type: + pickle_type.load_dialect_impl(mock_dialect) + mock_dialect.type_descriptor.assert_called_once_with(mock_spanner_type) + + +def test_load_dialect_impl_default(pickle_type): + """Test that other dialects use default PickleType.""" + engine = create_engine("sqlite:///:memory:") + dialect = engine.dialect + impl = pickle_type.load_dialect_impl(dialect) + # Should return the default impl (PickleType) + assert impl == pickle_type.impl + + +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_process_bind_param_pickle_dialects(pickle_type, dialect_name): + """Test that MySQL and Spanner dialects pickle the value.""" 
+ mock_dialect = mock.Mock() + mock_dialect.name = dialect_name + + test_data = {"key": "value", "nested": [1, 2, 3]} + result = pickle_type.process_bind_param(test_data, mock_dialect) + + # Should be pickled bytes + assert isinstance(result, bytes) + # Should be able to unpickle back to original + assert pickle.loads(result) == test_data + + +def test_process_bind_param_default(pickle_type): + """Test that other dialects return value as-is.""" + mock_dialect = mock.Mock() + mock_dialect.name = "sqlite" + + test_data = {"key": "value"} + result = pickle_type.process_bind_param(test_data, mock_dialect) + + # Should return value unchanged (SQLAlchemy's PickleType handles it) + assert result == test_data + + +def test_process_bind_param_none(pickle_type): + """Test that None values are handled correctly.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + result = pickle_type.process_bind_param(None, mock_dialect) + assert result is None + + +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_process_result_value_pickle_dialects(pickle_type, dialect_name): + """Test that MySQL and Spanner dialects unpickle the value.""" + mock_dialect = mock.Mock() + mock_dialect.name = dialect_name + + test_data = {"key": "value", "nested": [1, 2, 3]} + pickled_data = pickle.dumps(test_data) + + result = pickle_type.process_result_value(pickled_data, mock_dialect) + + # Should be unpickled back to original + assert result == test_data + + +def test_process_result_value_default(pickle_type): + """Test that other dialects return value as-is.""" + mock_dialect = mock.Mock() + mock_dialect.name = "sqlite" + + test_data = {"key": "value"} + result = pickle_type.process_result_value(test_data, mock_dialect) + + # Should return value unchanged (SQLAlchemy's PickleType handles it) + assert result == test_data + + +def test_process_result_value_none(pickle_type): + """Test that None values are handled correctly.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + result = pickle_type.process_result_value(None, mock_dialect) + assert result is None + + +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_roundtrip_pickle_dialects(pickle_type, dialect_name): + """Test full roundtrip for MySQL and Spanner: bind -> result.""" + mock_dialect = mock.Mock() + mock_dialect.name = dialect_name + + original_data = { + "string": "test", + "number": 42, + "list": [1, 2, 3], + "nested": {"a": 1, "b": 2}, + } + + # Simulate bind (Python -> DB) + bound_value = pickle_type.process_bind_param(original_data, mock_dialect) + assert isinstance(bound_value, bytes) + + # Simulate result (DB -> Python) + result_value = pickle_type.process_result_value(bound_value, mock_dialect) + assert result_value == original_data From 5947c41b554aca905e795b49aefc60b6c85be05f Mon Sep 17 00:00:00 2001 From: Bo Yang Date: Wed, 3 Dec 2025 13:40:21 -0800 Subject: [PATCH 62/63] chore: Update component owners Co-authored-by: Bo Yang PiperOrigin-RevId: 839896507 --- .../samples/adk_triaging_agent/agent.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/contributing/samples/adk_triaging_agent/agent.py b/contributing/samples/adk_triaging_agent/agent.py index d3e653f1d0..19096ce8eb 100644 --- a/contributing/samples/adk_triaging_agent/agent.py +++ b/contributing/samples/adk_triaging_agent/agent.py @@ 
-26,20 +26,21 @@ import requests LABEL_TO_OWNER = { + "a2a": "seanzhou1023", "agent engine": "yeesian", - "documentation": "polong-lin", - "services": "DeanChensj", - "question": "", - "mcp": "seanzhou1023", - "tools": "seanzhou1023", + "auth": "seanzhou1023", + "bq": "shobsi", + "core": "Jacksunwei", + "documentation": "joefernandez", "eval": "ankursharmas", - "live": "hangfei", - "models": "genquan9", + "live": "seanzhou1023", + "mcp": "seanzhou1023", + "models": "xuanyang15", + "services": "DeanChensj", + "tools": "xuanyang15", "tracing": "jawoszek", - "core": "Jacksunwei", "web": "wyf7107", - "a2a": "seanzhou1023", - "bq": "shobsi", + "workflow": "DeanChensj", } LABEL_GUIDELINES = """ @@ -65,6 +66,8 @@ Agent Engine concepts, do not use this label—choose "core" instead. - "a2a": Agent-to-agent workflows, coordination logic, or A2A protocol. - "bq": BigQuery integration or general issues related to BigQuery. + - "workflow": Workflow agents and workflow execution. + - "auth": Authentication or authorization issues. When unsure between labels, prefer the most specific match. If a label cannot be assigned confidently, do not call the labeling tool. @@ -265,6 +268,8 @@ def change_issue_type(issue_number: int, issue_type: str) -> dict[str, Any]: - If it's about Model Context Protocol (e.g. MCP tool, MCP toolset, MCP session management etc.), label it with both "mcp" and "tools". - If it's about A2A integrations or workflows, label it with "a2a". - If it's about BigQuery integrations, label it with "bq". + - If it's about workflow agents or workflow execution, label it with "workflow". + - If it's about authentication, label it with "auth". - If you can't find an appropriate labels for the issue, follow the previous instruction that starts with "IMPORTANT:". 
## Triaging Workflow From 960b206752918d13f127a9d6ed8d21d34bcbc7fa Mon Sep 17 00:00:00 2001 From: Ankur Sharma Date: Wed, 3 Dec 2025 15:02:54 -0800 Subject: [PATCH 63/63] chore: Bumps version to v1.20.0 and updates CHANGELOG.md Co-authored-by: Ankur Sharma PiperOrigin-RevId: 839930279 --- CHANGELOG.md | 41 +++++++++++++++++++++++++++++++++++++++ src/google/adk/version.py | 2 +- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 72e0c7b19f..93dc505adc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,46 @@ # Changelog +## [1.20.0](https://github.com/google/adk-python/compare/v1.19.0...v1.20.0) (2025-12-01) + + +### Features +* **[Core]** + * Add enum constraint to `agent_name` for `transfer_to_agent` ([4a42d0d](https://github.com/google/adk-python/commit/4a42d0d9d81b7aab98371427f70a7707dbfb8bc4)) + * Add validation for unique sub-agent names ([#3557](https://github.com/google/adk-python/issues/3557)) ([2247a45](https://github.com/google/adk-python/commit/2247a45922afdf0a733239b619f45601d9b325ec)) + * Support streaming function call arguments in progressive SSE streaming feature ([786aaed](https://github.com/google/adk-python/commit/786aaed335e1ce64b7e92dff2f4af8316b2ef593)) + +* **[Models]** + * Enable multi-provider support for Claude and LiteLLM ([d29261a](https://github.com/google/adk-python/commit/d29261a3dc9c5a603feef27ea657c4a03bb8a089)) + +* **[Tools]** + * Create APIRegistryToolset to add tools from Cloud API registry to agent ([ec4ccd7](https://github.com/google/adk-python/commit/ec4ccd718feeadeb6b2b59fcc0e9ff29a4fd0bac)) + * Add an option to disallow propagating runner plugins to AgentTool runner ([777dba3](https://github.com/google/adk-python/commit/777dba3033a9a14667fb009ba017f648177be41d)) + +* **[Web]** + * Added an endpoint to list apps with details ([b57fe5f](https://github.com/google/adk-python/commit/b57fe5f4598925ec7592917bb32c7f0d6eca287a)) + + +### Bug Fixes + +* Allow image parts in user messages for Anthropic Claude ([5453b5b](https://github.com/google/adk-python/commit/5453b5bfdedc91d9d668c9eac39e3bb009a7bbbf)) +* Mark the Content as non-empty if its first part contains text or inline_data or file_data or func call/response ([631b583](https://github.com/google/adk-python/commit/631b58336d36bfd93e190582be34069613d38559)) +* Fixes double response processing issue in `base_llm_flow.py` where, in Bidi-streaming (live) mode, the multi-agent structure causes duplicated responses after tool calling. 
([cf21ca3](https://github.com/google/adk-python/commit/cf21ca358478919207049695ba6b31dc6e0b2673)) +* Fix out of bounds error in _run_async_impl ([8fc6128](https://github.com/google/adk-python/commit/8fc6128b62ba576480d196d4a2597564fd0a7006)) +* Fix paths for public docs ([cd54f48](https://github.com/google/adk-python/commit/cd54f48fed0c87b54fb19743c9c75e790c5d9135)) +* Ensure request bodies without explicit names are named 'body' ([084c2de](https://github.com/google/adk-python/commit/084c2de0dac84697906e2b4beebf008bbd9ae8e1)), closes [#2213](https://github.com/google/adk-python/issues/2213) +* Optimize Stale Agent with GraphQL and Search API to resolve 429 Quota errors ([cb19d07](https://github.com/google/adk-python/commit/cb19d0714c90cd578551753680f39d8d6076c79b)) +* Update AgentTool to use Agent's description when input_schema is provided in FunctionDeclaration ([52674e7](https://github.com/google/adk-python/commit/52674e7fac6b7689f0e3871d41c4523e13471a7e)) +* Update LiteLLM system instruction role from "developer" to "system" ([2e1f730](https://github.com/google/adk-python/commit/2e1f730c3bc0eb454b76d7f36b7b9f1da7304cfe)), closes [#3657](https://github.com/google/adk-python/issues/3657) +* Update session last update time when appending events ([a3e4ad3](https://github.com/google/adk-python/commit/a3e4ad3cd130714affcaa880f696aeb498cd93af)), closes [#2721](https://github.com/google/adk-python/issues/2721) +* Update the retry_on_closed_resource decorator to retry on all errors ([a3aa077](https://github.com/google/adk-python/commit/a3aa07722a7de3e08807e86fd10f28938f0b267d)) +* Windows Path Handling and Normalize Cross-Platform Path Resolution in AgentLoader ([a1c09b7](https://github.com/google/adk-python/commit/a1c09b724bb37513eaabaff9643eeaa68014f14d)) + + +### Documentation + +* Add Code Wiki badge to README ([caf23ac](https://github.com/google/adk-python/commit/caf23ac49fe08bc7f625c61eed4635c26852c3ba)) + + ## [1.19.0](https://github.com/google/adk-python/compare/v1.18.0...v1.19.0) (2025-11-19) ### Features diff --git a/src/google/adk/version.py b/src/google/adk/version.py index 9478b2a547..a287db284c 100644 --- a/src/google/adk/version.py +++ b/src/google/adk/version.py @@ -13,4 +13,4 @@ # limitations under the License. # version: major.minor.patch -__version__ = "1.19.0" +__version__ = "1.20.0"