
Conversation

@rkuester
Contributor

This is a draft PR for running CI tests on the full change. The commits
along this branch will be individually submitted for review.

See the linked issue for a description of the change.

BUG=implements #3256

Implement unified module for creating, reading, and modifying TFLite
models with a clean API. The module eliminates manual index tracking
and buffer management through automatic bookkeeping, supporting both
declarative and imperative construction styles.

Core design uses first-class Buffer objects that can be shared between
tensors, with automatic deduplication during build. Tensors reference
Buffers directly, matching the TFLite schema structure. The compiler
automatically extracts inline tensor declarations, builds operator code
tables, and handles index assignment according to TFLite conventions.

Supports quantization parameters (per-tensor and per-channel), metadata
key-value pairs, and read-modify-write workflows. The read() function
preserves the object graph structure, enabling models to be read,
modified, and rebuilt.

Add comprehensive test coverage for core functionality, advanced
features, quantization, and modification workflows.
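The buffer-sharing and deduplication behavior described above can be sketched with stand-in dataclasses (the real module's classes are richer; `Buffer`, `Tensor`, and the index-assignment helper here are illustrative assumptions, not the actual implementation):

```python
from dataclasses import dataclass
from typing import List, Optional

# Illustrative sketch: tensors reference first-class Buffer objects directly,
# and the build step deduplicates shared buffers by object identity.
@dataclass(eq=False)  # eq=False: buffers are compared by identity, not content
class Buffer:
    data: bytes = b""

@dataclass
class Tensor:
    shape: tuple
    buffer: Optional[Buffer] = None
    name: Optional[str] = None

def assign_buffer_indices(tensors: List[Tensor]) -> dict:
    """Assign each distinct Buffer one index, reusing indices for shared
    buffers. Index 0 is reserved for the empty buffer, per TFLite convention."""
    indices = {}  # id(buffer) -> flatbuffer index
    next_index = 1
    for t in tensors:
        if t.buffer is not None and id(t.buffer) not in indices:
            indices[id(t.buffer)] = next_index
            next_index += 1
    return indices

shared = Buffer(b"\x01\x02")
a = Tensor((2,), buffer=shared, name="a")
b = Tensor((2,), buffer=shared, name="b")  # shares the same Buffer object
c = Tensor((1,), buffer=Buffer(b"\x03"), name="c")
table = assign_buffer_indices([a, b, c])  # a and b share one index
```

This is the bookkeeping the commit message says the user no longer does by hand: tensors never hold raw buffer indices, so appending or sharing a buffer requires no manual renumbering.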
…_editor

Replace model_facade with model_editor in compress.py and tests.
model_editor provides a cleaner API with better buffer and metadata
handling. Buffers appended during compression are automatically indexed,
and quantization parameters are accessed through a wrapper object.

Update BUILD dependencies accordingly.
Remove model_facade module and its tests as they are superseded by
model_editor.
…ess_test

Replace dictionary-based test_models.build() with model_editor's
declarative API. Add _build_test_model() function that uses model_editor
to create the same test model more cleanly.
Remove test_models module and its tests as they are superseded by
model_editor.
Add DecodeType class to replace raw integer decode_type field with
named constants and factory methods. Provides predefined constants for
built-in types (LUT, HUFFMAN, PRUNING) and a factory method for custom
types (128-255).

Custom types are automatically named CUSTOM_{code} for clarity in
debugging. The class supports serialization via __int__() and comparison
with both DecodeType objects and integers.

Update DecodeCommonMetadata to use DecodeType and update tests to use
named constants.

Add decode module BUILD targets.
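A minimal sketch of a class with the described shape (the integer codes for LUT, HUFFMAN, and PRUNING below are assumptions for illustration, not the actual values):

```python
class DecodeType:
    """Named wrapper around a raw integer decode_type code."""

    def __init__(self, code: int, name: str):
        self._code = code
        self._name = name

    @classmethod
    def custom(cls, code: int) -> "DecodeType":
        """Factory for custom types in the 128-255 range."""
        if not 128 <= code <= 255:
            raise ValueError(f"custom decode type must be in 128-255, got {code}")
        return cls(code, f"CUSTOM_{code}")

    def __int__(self) -> int:
        # Serialization: the flatbuffer still stores a raw integer.
        return self._code

    def __eq__(self, other) -> bool:
        # Compare with both DecodeType objects and plain integers.
        if isinstance(other, DecodeType):
            return self._code == other._code
        if isinstance(other, int):
            return self._code == other
        return NotImplemented

    def __repr__(self) -> str:
        return self._name

# Predefined constants for built-in types (codes here are illustrative).
DecodeType.LUT = DecodeType(0, "LUT")
DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN")
DecodeType.PRUNING = DecodeType(2, "PRUNING")
```

The integer comparison keeps existing call sites that pass raw codes working, while `repr` makes test failures readable.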
Add dataclass placeholders for future compression methods. These will
be used by the plugin architecture to dispatch to compression-specific
implementations.
Factor out compression method parsing into a dedicated function that
dispatches on the YAML key. This enables parse_yaml to iterate over
multiple compression methods per tensor and makes adding new
compression types straightforward.
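The dispatch-on-YAML-key shape might look like the following sketch (the key names, spec dataclasses, and parameter names are assumptions for illustration):

```python
from dataclasses import dataclass

@dataclass
class LutCompressionSpec:
    index_bitwidth: int

@dataclass
class HuffmanCompressionSpec:
    pass  # placeholder; no parameters yet

def _parse_compression_method(key: str, params: dict):
    """Dispatch one YAML entry to its method-specific parser."""
    if key == "lut":
        return LutCompressionSpec(index_bitwidth=params["index_bitwidth"])
    if key == "huffman":
        return HuffmanCompressionSpec()
    raise ValueError(f"unknown compression method: {key}")

def parse_methods(tensor_entry: dict):
    """A tensor entry may list multiple compression methods."""
    return [_parse_compression_method(k, v) for k, v in tensor_entry.items()]

methods = parse_methods({"lut": {"index_bitwidth": 4}})
```

Adding a new compression type then only requires a new branch (or table entry) in `_parse_compression_method`, without touching the per-tensor iteration.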
Define the plugin interface for compression methods. Each compressor
implements the Compressor protocol with a compress() method that
returns encoded data and ancillary data. CompressionError provides
a common exception type for compression failures.
Extract LUT compression logic from compress.py into a dedicated plugin
module. The LutCompressor class implements the Compressor protocol,
producing packed indices and ancillary data in the format expected by
the C++ DECODE kernel.
Add placeholder implementations that raise CompressionError when
invoked. These validate the plugin architecture and will be replaced
with working implementations later.
Implement graph modification to insert DECODE operators before
consumers of compressed tensors. Each compressed tensor gets a DECODE
operator with two inputs (encoded tensor and ancillary data tensor)
and one output (decompressed tensor). Consumer operators are rewired
to use the DECODE output.
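The rewiring can be sketched with plain dicts standing in for the real editor objects (all names here are illustrative, not the module's API):

```python
# Sketch of DECODE insertion: one DECODE per compressed tensor, placed
# before its first consumer, with every consumer rewired to the output.

def insert_decode(operators, encoded, ancillary, decoded):
    """Insert a DECODE op before the first consumer of `encoded` and
    rewire all consumers to read the decoded output instead."""
    consumers = [op for op in operators if encoded in op["inputs"]]
    if not consumers:
        return operators
    insert_pos = operators.index(consumers[0])
    decode_op = {
        "opcode": "DECODE",
        "inputs": [encoded, ancillary],  # two inputs: encoded + ancillary
        "outputs": [decoded],            # one output: decompressed tensor
    }
    operators.insert(insert_pos, decode_op)
    for op in consumers:
        op["inputs"] = [decoded if t == encoded else t for t in op["inputs"]]
    return operators

ops = [{"opcode": "FULLY_CONNECTED", "inputs": ["in", "w_enc"], "outputs": ["out"]}]
ops = insert_decode(ops, "w_enc", "w_anc", "w_dec")
```

Note that a later commit in this branch replaces this shared-output scheme with one DECODE per consumer.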
Replace monolithic compression logic with a dispatch table that routes
compression requests to plugin modules based on the spec's compression
method type. After compressing tensors, insert DECODE operators into
the model graph.

The old metadata flatbuffer approach is removed in favor of the DECODE
operator format.
The TFLM interpreter requires subgraph inputs/outputs to be set in the
flatbuffer to know which tensors are model inputs and outputs. Without
these, models built with model_editor cannot be executed.

Add inputs and outputs fields to Subgraph dataclass, populate them in
_compile_subgraph when building, and read them back in read().
Add tests that compress models with LUT compression, run them through
the TFLM Python interpreter, and verify outputs match uncompressed
originals. Also verify DECODE operators are inserted and that compressed
models are smaller than originals.

Tests only run when compression is enabled (--//:with_compression).
Placeholder tests for Huffman and Pruning are skipped until implemented.
Add alt_decompression_memory_size parameter to the Python interpreter
API. When non-zero, allocates a separate memory region for DECODE
operator outputs and calls SetDecompressionMemory before AllocateTensors.

SetDecompressionMemory stores a pointer to its initializer_list
argument, requiring the list to outlive the interpreter. Per C++
standard, an initializer_list's backing array lifetime is only extended
to match the list's when initialized in a declaration, not when
assigned. This makes the API difficult to use correctly.
Add test for shared compressed tensors with alternate decompression
memory. The test is marked expectedFailure to document the current
mismatch between interpreter and DECODE insertion: the interpreter's
alt decompression memory resets allocations for each DECODE, but the
insertion code shares one DECODE output among all consumers.

The workaround is to insert a separate DECODE before each consumer.
The expectedFailure decorator should be removed once this is
implemented.
…mory

Insert a separate DECODE immediately before each consumer of a
compressed tensor, rather than sharing one DECODE output among all
consumers.

The interpreter's alternate decompression memory resets its allocation
offset for each DECODE's Prepare, causing all DECODE outputs to be
allocated at the same address. If two consumers share one DECODE and
another DECODE runs between them, the intervening DECODE overwrites the
shared output, corrupting data for the second consumer.

Update test expectations to reflect the new DECODE-per-consumer
behavior and change the integration test from expected-failure to
expected-pass.
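The per-consumer variant can be sketched the same way as before, with dicts standing in for the real editor objects (names are illustrative):

```python
# Sketch of per-consumer DECODE insertion: each consumer of a compressed
# tensor gets its own DECODE op and its own distinct output tensor, so an
# intervening DECODE can no longer clobber a shared output.

def insert_decode_per_consumer(operators, encoded, ancillary, make_output):
    """Insert a fresh DECODE immediately before each consumer of `encoded`."""
    result = []
    for op in operators:
        if encoded in op["inputs"]:
            decoded = make_output()  # distinct output tensor per consumer
            result.append({
                "opcode": "DECODE",
                "inputs": [encoded, ancillary],
                "outputs": [decoded],
            })
            op["inputs"] = [decoded if t == encoded else t for t in op["inputs"]]
        result.append(op)
    return result

counter = iter(range(100))
ops = [
    {"opcode": "ADD", "inputs": ["w_enc", "x"], "outputs": ["y"]},
    {"opcode": "MUL", "inputs": ["w_enc", "y"], "outputs": ["z"]},
]
ops = insert_decode_per_consumer(ops, "w_enc", "w_anc",
                                 lambda: f"w_dec_{next(counter)}")
```

This trades extra DECODE invocations (and duplicated ancillary inputs) for correctness under the interpreter's reset-per-Prepare allocation scheme.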
Add tests demonstrating bugs in model_editor.read() when parsing models
with None values for tensor shape, operator inputs, or operator outputs.
These edge cases can occur in real models from the TFLite converter but
cause TypeError crashes in the current implementation.

Tests construct models using the low-level TFLite schema to reproduce
these conditions. Marked as expectedFailure until the fix is applied.
The TFLite flatbuffer schema allows None values for tensor shape
(representing scalars) and operator inputs/outputs (for certain ops).
Handle these cases in read() to avoid TypeError when iterating.

Remove expectedFailure decorators from edge case tests now that the
fix is applied.
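The None-tolerant reads amount to a small guard like the following sketch (a generic illustration of the fix's shape, not the module's actual code):

```python
# A scalar tensor's shape, or an operator's inputs/outputs, may be stored
# as None in the flatbuffer; normalize to an empty list before iterating.

def _as_list(maybe_seq):
    """Return [] for None so callers can iterate unconditionally."""
    return list(maybe_seq) if maybe_seq is not None else []

shape = _as_list(None)        # scalar tensor: shape absent in the flatbuffer
inputs = _as_list([0, 2, 3])  # normal case passes through unchanged
```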
Comment on lines +204 to +208
ancillary_tensor = _create_ancillary_tensor(
    info.ancillary_data,
    info.tensor,
)
subgraph.tensors.append(ancillary_tensor)
Member

This appears to create too many ancillary tensor copies. Each encoded tensor can reuse the associated ancillary tensor. Should _create_ancillary_tensor have a cache keyed on the original tensor?

subgraph.operators.insert(insert_pos, decode_op)

# Rewire only this consumer to use the decoded output
_rewire_consumers([consumer], info.tensor, output_tensor)
Member

The code to comply with "Input tensor pairs" and "Tensor rewriting" in the design doc seems to be missing. Should there be another loop, after the loop with _rewire_consumers, that iterates compression_results and rewrites the original tensor (removing quantization, setting the type to UINT8, and changing the shape to rank 1 with length equal to the number of bytes in the encoded tensor)?

Comment on lines +341 to +342
if len(scales) > 1 and fb_quant.quantizedDimension is not None:
    axis = fb_quant.quantizedDimension
Member

The axis should always be copied if it exists. It is perfectly valid for len(scales) to be 1, the quantized_dimension to be 3, and shape[3] to equal 1.

scales = list(fb_quant.scale)
zeros = list(
    fb_quant.zeroPoint
) if fb_quant.zeroPoint is not None else [0] * len(scales)
Member

The zero_point vector should never be expanded, only copied. It is valid to generate quantization data where the length of the scale vector is > 1 and the zero_point vector has length 1 (this is a converter optimization). This case is already handled in TFLM.
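The copy-only behavior the reviewer asks for might look like this sketch (field names follow the flatbuffer object-API style but are assumptions here):

```python
# Keep zero_point at its stored length rather than expanding it to match
# scale; a length-1 zero_point alongside multiple scales is a valid
# converter optimization that TFLM already handles downstream.

def read_quant(scale, zero_point):
    scales = list(scale)
    zeros = list(zero_point) if zero_point is not None else []
    return scales, zeros

scales, zeros = read_quant([0.1, 0.2, 0.3], [0])
```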

Comment on lines +395 to +399
tflite.TensorType.INT8: np.int8,
tflite.TensorType.INT16: np.int16,
tflite.TensorType.INT32: np.int32,
tflite.TensorType.UINT8: np.uint8,
tflite.TensorType.FLOAT32: np.float32,
Member

missing INT64

Comment on lines +100 to +104
shape: tuple
dtype: tflite.TensorType
buffer: Optional[Buffer] = None
quantization: Optional[Quantization] = None
name: Optional[str] = None
Member

is_variable must be maintained

Comment on lines +205 to +207
builtin_code: tflite.BuiltinOperator
custom_code: Optional[str] = None
version: int = 1
Member

deprecated_builtin_code must be maintained, and it must interoperate with builtin_code as described in the schema.fbs comments.

Comment on lines +213 to +219
opcode: Union[tflite.BuiltinOperator, int]
inputs: List[Tensor]
outputs: List[Tensor]
custom_code: Optional[str] = None

# Set when reading from existing model
opcode_index: Optional[int] = None
Member

Must maintain the following fields:

  builtin_options:BuiltinOptions;
  custom_options:[ubyte];
  custom_options_format:CustomOptionsFormat;

  // A list of indices to the subgraph's "tensors" that are internal to an Op.
  // Internal tensors are those that do not flow in or out of the operation,
  // but instead are part of internal computation. As such, the operation's
  // implementation may manage its memory more efficiently. They are needed
  // however (i.e. not just an implementation detail) since they are part of the
  // computation, which may require relevant metadata such as quantization
  // parameters.
  intermediates:[int];

  // Flatbuffers union struct has a 128 elements limit in JAVA, so a second
  // union is added, in the case of where BuitlinOptions2 runs out, a third
  // one can be added
  builtin_options_2 : BuiltinOptions2;

  // Index into operators_debug_metadata list.
  debug_metadata_index: int = -1;

builtin_options and builtin_options_2 must interoperate correctly.

Comment on lines +227 to +231
tensors: List[Tensor] = field(default_factory=list)
operators: List[Operator] = field(default_factory=list)
inputs: List[Tensor] = field(default_factory=list)
outputs: List[Tensor] = field(default_factory=list)
name: Optional[str] = None
Member

debug_metadata_index should be maintained.

Comment on lines +262 to +267
subgraphs: List[Subgraph] = field(default_factory=list)
buffers: _BufferList = field(
default_factory=_BufferList) # Auto-sets buffer.index on append
operator_codes: List[OperatorCode] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
description: Optional[str] = None
Member

must maintain the following fields:

version
signature_defs
