🌐 AI搜索 & 代理 主页
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
bfdd5d7
feat(branch-context): Implement token-set provenance for parallel age…
dannovikov Dec 8, 2025
ca0c0e5
Rename BranchContext -> Branch and update tests
dannovikov Dec 9, 2025
35361cf
Remove extra files created during debugging
dannovikov Dec 9, 2025
4612e05
restore accidentally deleted file
dannovikov Dec 9, 2025
1922f0e
restore accidentally deleted file
dannovikov Dec 9, 2025
191cbe0
run autoformat
dannovikov Dec 9, 2025
bf04470
add comment and delete extra file
dannovikov Dec 9, 2025
31d849d
make fork return one branch
dannovikov Dec 9, 2025
085158a
remove redundant code
dannovikov Dec 9, 2025
db566cc
update branch types
dannovikov Dec 9, 2025
63ed48d
remove extraneous comment
dannovikov Dec 9, 2025
3136ee8
event.branch docstring update
dannovikov Dec 9, 2025
65a74a9
rename tokenfactory to BranchTokenFactory
dannovikov Dec 9, 2025
f8ee681
branch optional
dannovikov Dec 9, 2025
e8c5c06
tighten up tests
dannovikov Dec 9, 2025
f21376b
remove unused import
dannovikov Dec 9, 2025
365d71b
tidy up
dannovikov Dec 9, 2025
2e52beb
revert changes to event converter
dannovikov Dec 9, 2025
7357ba1
tidying up
dannovikov Dec 9, 2025
34933c9
update comments
dannovikov Dec 9, 2025
33010b0
invocation context branch not optional
dannovikov Dec 10, 2025
4e3d0a9
format and rename
dannovikov Dec 10, 2025
e80faf4
Merge remote-tracking branch 'origin/main' into fix/event_filtering
dannovikov Dec 10, 2025
fc61208
Address gemini's comments
dannovikov Dec 10, 2025
05ae73d
Address gemini's comments 2
dannovikov Dec 10, 2025
250c3f1
fix failing unit test missing import
dannovikov Dec 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
4 changes: 4 additions & 0 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,10 @@ async def run_async(
async for event in agen:
yield event

# Propagate branch changes back to parent context.
if ctx.branch != parent_context.branch:
parent_context.branch = ctx.branch

if ctx.end_invocation:
return

Expand Down
163 changes: 163 additions & 0 deletions src/google/adk/agents/branch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Branch context for provenance-based event filtering in parallel agents."""

from __future__ import annotations

import threading
from typing import Optional

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_serializer
from pydantic import PrivateAttr


class BranchTokenFactory:
"""Thread-safe global counter for branch tokens.

Each fork operation in a parallel agent execution creates new unique tokens
that are used to track provenance and determine event visibility across
branches WITHIN a single invocation.

The counter resets at the start of each invocation, ensuring tokens are
only used for parallel execution isolation within that invocation. Events
from previous invocations are always visible (branch filtering only applies
within current invocation).
"""

_lock = threading.Lock()
_next = 0

@classmethod
def new_token(cls) -> int:
"""Generate a new unique token.

Returns:
A unique integer token.
"""
with cls._lock:
cls._next += 1
return cls._next

@classmethod
def reset(cls) -> None:
"""Reset the counter to zero.

This should be called at the start of each invocation to ensure tokens
are fresh for that invocation's parallel execution tracking.
"""
with cls._lock:
cls._next = 0


class Branch(BaseModel):
"""Branch tracking using token sets for parallel agent execution.

Tracks event provenance across parallel and sequential agent execution.
Event visibility is determined by subset relationships: an event is visible
to a context if all the event's tokens are present in the context's token set.

Example:
Root context: {}
After fork(): child_0 has {1}, child_1 has {2}
After join: parent has {1, 2}

Events from child_0 (tokens={1}) are visible to parent (tokens={1,2})
because {1} ⊆ {1,2}.
"""

model_config = ConfigDict(
frozen=True, # Make instances immutable for hashing
arbitrary_types_allowed=True,
)
"""The pydantic model config."""

tokens: frozenset[int] = Field(default_factory=frozenset)
"""Set of integer tokens representing branch provenance.

If empty, represents the root context. Use frozenset for immutability
and to enable hashing for use in sets/dicts.
"""

@model_serializer
def serialize_model(self):
"""Custom serializer to convert frozenset to list for JSON serialization."""
return {'tokens': list(self.tokens)}

def fork(self) -> Branch:
"""Create a child context for parallel execution.

The child gets a unique new token added to the parent's token set.
This ensures:
1. Child can see parent's events (parent tokens ⊆ child tokens)
2. Siblings cannot see each other's events (sibling tokens are disjoint)

Returns:
A new Branch with parent.tokens ∪ {new_token}.
"""
new_token = BranchTokenFactory.new_token()
return Branch(tokens=self.tokens | {new_token})

def join(self, others: list[Branch]) -> Branch:
"""Merge token sets from parallel branches.

This is called when parallel execution completes and we need to merge
the provenance from all branches. The result contains the union of all
token sets, ensuring subsequent agents can see events from all branches.

Args:
others: List of other Branches to join with self.

Returns:
New Branch with union of all token sets.
"""
combined = set(self.tokens)
for ctx in others:
combined |= ctx.tokens
return Branch(tokens=frozenset(combined))

def can_see(self, event_ctx: Branch) -> bool:
"""Check if an event is visible from this context.

An event is visible if all of its tokens are present in the current
context's token set (subset relationship).

Args:
event_ctx: The Branch of the event to check.

Returns:
True if the event is visible, False otherwise.
"""
return event_ctx.tokens.issubset(self.tokens)

def __str__(self) -> str:
"""Human-readable string representation.

Returns:
String showing token set or "root" if empty.
"""
if not self.tokens:
return 'Branch(root)'
return f'Branch({sorted(self.tokens)})'

def __repr__(self) -> str:
"""Developer representation.

Returns:
String representation for debugging.
"""
return str(self)
18 changes: 8 additions & 10 deletions src/google/adk/agents/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from .active_streaming_tool import ActiveStreamingTool
from .base_agent import BaseAgent
from .base_agent import BaseAgentState
from .branch import Branch
from .context_cache_config import ContextCacheConfig
from .live_request_queue import LiveRequestQueue
from .run_config import RunConfig
Expand Down Expand Up @@ -149,15 +150,8 @@ class InvocationContext(BaseModel):

invocation_id: str
"""The id of this invocation context. Readonly."""
branch: Optional[str] = None
"""The branch of the invocation context.

The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of
agent_2, and agent_2 is the parent of agent_3.

Branch is used when multiple sub-agents shouldn't see their peer agents'
conversation history.
"""
branch: Branch = Field(default_factory=Branch)
"""The branch context tracking event provenance for visibility filtering."""
agent: BaseAgent
"""The current agent of this invocation context. Readonly."""
user_content: Optional[types.Content] = None
Expand Down Expand Up @@ -349,7 +343,11 @@ def _get_events(
if event.invocation_id == self.invocation_id
]
if current_branch:
results = [event for event in results if event.branch == self.branch]
results = [
event
for event in results
if event.branch is None or event.branch == self.branch
]
return results

def should_pause_invocation(self, event: Event) -> bool:
Expand Down
16 changes: 10 additions & 6 deletions src/google/adk/agents/parallel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .base_agent import BaseAgent
from .base_agent import BaseAgentState
from .base_agent_config import BaseAgentConfig
from .branch import Branch
from .invocation_context import InvocationContext
from .parallel_agent_config import ParallelAgentConfig

Expand All @@ -39,12 +40,8 @@ def _create_branch_ctx_for_sub_agent(
) -> InvocationContext:
"""Create isolated branch for every sub-agent."""
invocation_context = invocation_context.model_copy()
branch_suffix = f'{agent.name}.{sub_agent.name}'
invocation_context.branch = (
f'{invocation_context.branch}.{branch_suffix}'
if invocation_context.branch
else branch_suffix
)
parent_branch = invocation_context.branch or Branch()
invocation_context.branch = parent_branch.fork()
return invocation_context


Expand Down Expand Up @@ -173,9 +170,11 @@ async def _run_async_impl(
yield self._create_agent_state_event(ctx)

agent_runs = []
sub_agent_contexts = []
# Prepare and collect async generators for each sub-agent.
for sub_agent in self.sub_agents:
sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx)
sub_agent_contexts.append(sub_agent_ctx)

# Only include sub-agents that haven't finished in a previous run.
if not sub_agent_ctx.end_of_agents.get(sub_agent.name):
Expand All @@ -197,6 +196,11 @@ async def _run_async_impl(
if pause_invocation:
return

# Join all child branches back together after parallel execution completes
parent_branch = ctx.branch or Branch()
joined_branch = parent_branch.join([c.branch for c in sub_agent_contexts])
ctx.branch = joined_branch

# Once all sub-agents are done, mark the ParallelAgent as final.
if ctx.is_resumable and all(
ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents
Expand Down
12 changes: 3 additions & 9 deletions src/google/adk/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pydantic import ConfigDict
from pydantic import Field

from ..agents.branch import Branch
from ..models.llm_response import LlmResponse
from .event_actions import EventActions

Expand Down Expand Up @@ -56,15 +57,8 @@ class Event(LlmResponse):
Agent client will know from this field about which function call is long running.
only valid for function call event
"""
branch: Optional[str] = None
"""The branch of the event.

The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of
agent_2, and agent_2 is the parent of agent_3.

Branch is used when multiple sub-agent shouldn't see their peer agents'
conversation history.
"""
branch: Optional[Branch] = None
"""The branch context of the event. Used for provenance-based event filtering in parallel agents."""

# The following are computed fields.
# Do not assign the ID. It will be assigned by the session.
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/flows/llm_flows/audio_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ async def _flush_cache_to_services(
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author=audio_cache[0].role,
branch=invocation_context.branch,
content=types.Content(
role=audio_cache[0].role,
parts=[
Expand Down
1 change: 1 addition & 0 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def get_author_for_event(llm_response):
id=Event.new_id(),
invocation_id=invocation_context.invocation_id,
author=get_author_for_event(llm_response),
branch=invocation_context.branch,
)

async with Aclosing(
Expand Down
Loading