feat(compression): update tooling to use DECODE operators #3257
Conversation
Implement a unified module for creating, reading, and modifying TFLite models with a clean API. The module eliminates manual index tracking and buffer management through automatic bookkeeping, and supports both declarative and imperative construction styles.

The core design uses first-class Buffer objects that can be shared between tensors, with automatic deduplication during build. Tensors reference Buffers directly, matching the TFLite schema structure. The compiler automatically extracts inline tensor declarations, builds operator code tables, and handles index assignment according to TFLite conventions. The module supports quantization parameters (per-tensor and per-channel), metadata key-value pairs, and read-modify-write workflows: the read() function preserves the object graph structure, enabling models to be read, modified, and rebuilt.

Add comprehensive test coverage for core functionality, advanced features, quantization, and modification workflows.
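A hypothetical sketch of the declarative style, using the dataclass fields visible in the diff excerpts later in this PR; the Buffer constructor, the tflite import, and the build() entry point are assumptions, not the module's confirmed API:

```python
# Hypothetical usage sketch; Tensor/Operator/Subgraph/Model field names come
# from the dataclasses shown in this diff, but Buffer(data=...) and build()
# are assumptions.
import numpy as np

weights_buf = Buffer(data=np.zeros((3, 4), np.int8).tobytes())  # assumed ctor
inp = Tensor(shape=(1, 3), dtype=tflite.TensorType.INT8, name="input")
weights = Tensor(shape=(3, 4), dtype=tflite.TensorType.INT8,
                 buffer=weights_buf, name="weights")
out = Tensor(shape=(1, 4), dtype=tflite.TensorType.INT8, name="output")

op = Operator(opcode=tflite.BuiltinOperator.FULLY_CONNECTED,
              inputs=[inp, weights], outputs=[out])
sg = Subgraph(tensors=[inp, weights, out], operators=[op],
              inputs=[inp], outputs=[out])
model = Model(subgraphs=[sg])
flatbuffer_bytes = build(model)  # assumed entry point; assigns all indices
```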
Replace model_facade with model_editor in compress.py and tests. model_editor provides a cleaner API with better buffer and metadata handling: buffers appended during compression are automatically indexed, and quantization parameters are accessed through a wrapper object. Update BUILD dependencies accordingly.
Remove model_facade module and its tests as they are superseded by model_editor.
Replace the dictionary-based test_models.build() with model_editor's declarative API. Add a _build_test_model() function that uses model_editor to create the same test model more cleanly.
Remove test_models module and its tests as they are superseded by model_editor.
Add a DecodeType class to replace the raw integer decode_type field with named constants and factory methods. It provides predefined constants for the built-in types (LUT, HUFFMAN, PRUNING) and a factory method for custom types (128-255).

Custom types are automatically named with a CUSTOM_{code} prefix for clarity in debugging. The class supports serialization via __int__() and comparison with both DecodeType objects and integers.

Update DecodeCommonMetadata to use DecodeType and update tests to use the named constants. Add decode module BUILD targets.
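A minimal sketch of what such a class could look like; the numeric codes assigned to the built-in constants below are placeholders, not values from this PR:

```python
# Minimal DecodeType sketch; built-in codes below are placeholders.
class DecodeType:

  def __init__(self, code: int, name: str):
    self.code = code
    self.name = name

  @classmethod
  def custom(cls, code: int) -> "DecodeType":
    # Custom types occupy 128-255 and are auto-named for debugging.
    if not 128 <= code <= 255:
      raise ValueError("custom decode types must be in the range 128-255")
    return cls(code, f"CUSTOM_{code}")

  def __int__(self) -> int:
    # Serialization: the flatbuffer stores the raw integer code.
    return self.code

  def __eq__(self, other) -> bool:
    # Comparable with both DecodeType objects and plain integers.
    if isinstance(other, DecodeType):
      return self.code == other.code
    if isinstance(other, int):
      return self.code == other
    return NotImplemented


DecodeType.LUT = DecodeType(0, "LUT")          # placeholder code
DecodeType.HUFFMAN = DecodeType(1, "HUFFMAN")  # placeholder code
DecodeType.PRUNING = DecodeType(2, "PRUNING")  # placeholder code
```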
Add dataclass placeholders for future compression methods. These will be used by the plugin architecture to dispatch to compression-specific implementations.
Factor out compression method parsing into a dedicated function that dispatches on the YAML key. This enables parse_yaml to iterate over multiple compression methods per tensor and makes adding new compression types straightforward.
Define the plugin interface for compression methods. Each compressor implements the Compressor protocol with a compress() method that returns encoded data and ancillary data. CompressionError provides a common exception type for compression failures.
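A sketch of the interface as described; the exact signature of compress() is an assumption:

```python
# Plugin interface sketch; the compress() signature is an assumption.
from typing import Protocol, Tuple

import numpy as np


class CompressionError(Exception):
  """Common exception type raised by compressors on failure."""


class Compressor(Protocol):

  def compress(self, values: np.ndarray) -> Tuple[bytes, bytes]:
    """Returns (encoded_data, ancillary_data) for the DECODE kernel."""
    ...
```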
Extract LUT compression logic from compress.py into a dedicated plugin module. The LutCompressor class implements the Compressor protocol, producing packed indices and ancillary data in the format expected by the C++ DECODE kernel.
Add placeholder implementations that raise CompressionError when invoked. These validate the plugin architecture and will be replaced with working implementations later.
Implement graph modification to insert DECODE operators before consumers of compressed tensors. Each compressed tensor gets a DECODE operator with two inputs (encoded tensor and ancillary data tensor) and one output (decompressed tensor). Consumer operators are rewired to use the DECODE output.
Replace monolithic compression logic with a dispatch table that routes compression requests to plugin modules based on the spec's compression method type. After compressing tensors, insert DECODE operators into the model graph. The old metadata flatbuffer approach is removed in favor of the DECODE operator format.
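A sketch of how such a dispatch table might look; the spec class names and plugin module names are assumptions based on the commits above:

```python
# Dispatch sketch; spec class and plugin module names are assumptions.
_COMPRESSORS = {
    LookUpTableCompression: lut.LutCompressor,
    HuffmanCompression: huffman.HuffmanCompressor,
    PruningCompression: pruning.PruningCompressor,
}


def _compress(method_spec, values):
  try:
    compressor_cls = _COMPRESSORS[type(method_spec)]
  except KeyError:
    raise CompressionError(f"unsupported compression method: {method_spec!r}")
  return compressor_cls().compress(values)
```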
The TFLM interpreter requires subgraph inputs/outputs to be set in the flatbuffer to know which tensors are model inputs and outputs. Without these, models built with model_editor cannot be executed. Add inputs and outputs fields to Subgraph dataclass, populate them in _compile_subgraph when building, and read them back in read().
Add tests that compress models with LUT compression, run them through the TFLM Python interpreter, and verify outputs match uncompressed originals. Also verify DECODE operators are inserted and that compressed models are smaller than originals. Tests only run when compression is enabled (--//:with_compression). Placeholder tests for Huffman and Pruning are skipped until implemented.
Add an alt_decompression_memory_size parameter to the Python interpreter API. When non-zero, it allocates a separate memory region for DECODE operator outputs and calls SetDecompressionMemory before AllocateTensors. SetDecompressionMemory stores a pointer to its initializer_list argument, requiring the list to outlive the interpreter. Per the C++ standard, an initializer_list's backing array has its lifetime extended to match the list's only when the list is initialized in a declaration, not when it is assigned. This makes the API difficult to use correctly.
Add test for shared compressed tensors with alternate decompression memory. The test is marked expectedFailure to document the current mismatch between interpreter and DECODE insertion: the interpreter's alt decompression memory resets allocations for each DECODE, but the insertion code shares one DECODE output among all consumers. The workaround is to insert a separate DECODE before each consumer. The expectedFailure decorator should be removed once this is implemented.
Insert a separate DECODE immediately before each consumer of a compressed tensor, rather than sharing one DECODE output among all consumers. The interpreter's alternate decompression memory resets its allocation offset for each DECODE's Prepare, causing all DECODE outputs to be allocated at the same address. If two consumers share one DECODE and another DECODE runs between them, the intervening DECODE overwrites the shared output, corrupting data for the second consumer. Update test expectations to reflect the new DECODE-per-consumer behavior and change the integration test from expected-failure to expected-pass.
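A sketch of the per-consumer insertion loop, anchored on the two lines quoted in the review thread below; _consumers_of, _create_decode_output, and DECODE_OPCODE are assumed names, not confirmed by this PR:

```python
# DECODE-per-consumer sketch; helper names are assumptions.
for info in compression_results:
  ancillary_tensor = _create_ancillary_tensor(info.ancillary_data, info.tensor)
  for consumer in _consumers_of(subgraph, info.tensor):
    output_tensor = _create_decode_output(info.tensor)
    decode_op = Operator(
        opcode=DECODE_OPCODE,
        inputs=[info.tensor, ancillary_tensor],
        outputs=[output_tensor])
    insert_pos = subgraph.operators.index(consumer)
    subgraph.operators.insert(insert_pos, decode_op)
    # Rewire only this consumer to use the decoded output
    _rewire_consumers([consumer], info.tensor, output_tensor)
```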
Add tests demonstrating bugs in model_editor.read() when parsing models with None values for tensor shape, operator inputs, or operator outputs. These edge cases can occur in real models from the TFLite converter but cause TypeError crashes in the current implementation. Tests construct models using the low-level TFLite schema to reproduce these conditions. Marked as expectedFailure until the fix is applied.
The TFLite flatbuffer schema allows None values for tensor shape (representing scalars) and operator inputs/outputs (for certain ops). Handle these cases in read() to avoid TypeError when iterating. Remove expectedFailure decorators from edge case tests now that the fix is applied.
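A sketch of the defensive handling in read(), following the flatbuffer object-API attribute style used in the diff excerpts below:

```python
# None-tolerant read() sketch: shape=None represents a scalar, and an
# operator's inputs/outputs may be None for certain ops.
shape = tuple(fb_tensor.shape) if fb_tensor.shape is not None else ()
inputs = [] if fb_op.inputs is None else [tensors[i] for i in fb_op.inputs]
outputs = [] if fb_op.outputs is None else [tensors[i] for i in fb_op.outputs]
```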
```python
ancillary_tensor = _create_ancillary_tensor(
    info.ancillary_data,
    info.tensor,
)
subgraph.tensors.append(ancillary_tensor)
```
This appears to create too many ancillary tensor copies. Each encoded tensor can reuse the associated ancillary tensor. Should _create_ancillary_tensor have a cache keyed on the original tensor?
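Something like this, keyed on object identity, could avoid the duplicates (illustrative only):

```python
# Illustrative cache keyed on the original tensor's identity.
_ancillary_cache: dict = {}


def _get_ancillary_tensor(ancillary_data, original_tensor):
  key = id(original_tensor)
  if key not in _ancillary_cache:
    _ancillary_cache[key] = _create_ancillary_tensor(ancillary_data,
                                                     original_tensor)
  return _ancillary_cache[key]
```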
```python
subgraph.operators.insert(insert_pos, decode_op)

# Rewire only this consumer to use the decoded output
_rewire_consumers([consumer], info.tensor, output_tensor)
```
The code to comply with "Input tensor pairs" and "Tensor rewriting" in the design doc seems to be missing. Should there be another loop, after the loop with _rewire_consumers, that iterates compression_results and rewrites the original tensor (removing quantization, setting the type to UINT8, and changing the shape to rank 1 with length equal to the number of bytes in the encoded tensor)?
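Illustratively (info.encoded_data is an assumed field name):

```python
# Illustrative rewrite pass over the compressed tensors.
for info in compression_results:
  t = info.tensor
  t.quantization = None
  t.dtype = tflite.TensorType.UINT8
  t.shape = (len(info.encoded_data),)  # rank 1, length = encoded byte count
```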
```python
if len(scales) > 1 and fb_quant.quantizedDimension is not None:
  axis = fb_quant.quantizedDimension
```
The axis should always be copied if it exists. It is perfectly valid for len(scales) to be 1, the quantized_dimension to be 3, and shape[3] to equal 1.
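For example (illustrative):

```python
# Copy the axis whenever the flatbuffer provides one, regardless of
# how many scales there are.
if fb_quant.quantizedDimension is not None:
  axis = fb_quant.quantizedDimension
```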
```python
scales = list(fb_quant.scale)
zeros = list(
    fb_quant.zeroPoint
) if fb_quant.zeroPoint is not None else [0] * len(scales)
```
The zero_point vector should never be expanded, only copied. It is valid to generate quantization data where the scale vector has length > 1 and the zero_point vector has length 1 (a converter optimization). This case is already handled in TFLM.
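For example (illustrative; the fallback for a missing zeroPoint is a judgment call, shown here as a single zero rather than an expanded vector):

```python
# Copy vectors verbatim; never expand a length-1 zero_point.
scales = list(fb_quant.scale)
zeros = list(fb_quant.zeroPoint) if fb_quant.zeroPoint is not None else [0]
```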
```python
tflite.TensorType.INT8: np.int8,
tflite.TensorType.INT16: np.int16,
tflite.TensorType.INT32: np.int32,
tflite.TensorType.UINT8: np.uint8,
tflite.TensorType.FLOAT32: np.float32,
```
This mapping is missing INT64.
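For example:

```python
tflite.TensorType.INT64: np.int64,
```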
```python
shape: tuple
dtype: tflite.TensorType
buffer: Optional[Buffer] = None
quantization: Optional[Quantization] = None
name: Optional[str] = None
```
The is_variable field must be maintained.
```python
builtin_code: tflite.BuiltinOperator
custom_code: Optional[str] = None
version: int = 1
```
Must maintain deprecated_builtin_code, and it must interoperate with builtin_code as described in the schema.fbs comments.
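A sketch of the interop rule from the schema.fbs comments (attribute names follow the flatbuffer object-API style used elsewhere in this diff):

```python
# deprecated_builtin_code is a byte capped at 127
# (BuiltinOperator.PLACEHOLDER_FOR_GREATER_OP_CODES), so readers take the
# max of the two fields and writers cap the deprecated field.
def resolve_builtin_code(fb_op_code) -> int:
  return max(fb_op_code.builtinCode, fb_op_code.deprecatedBuiltinCode)


def split_builtin_code(code: int) -> tuple:
  deprecated = min(code, 127)  # 127 == PLACEHOLDER_FOR_GREATER_OP_CODES
  return code, deprecated
```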
```python
opcode: Union[tflite.BuiltinOperator, int]
inputs: List[Tensor]
outputs: List[Tensor]
custom_code: Optional[str] = None

# Set when reading from existing model
opcode_index: Optional[int] = None
```
Must maintain the following fields from the schema:

```
builtin_options:BuiltinOptions;
custom_options:[ubyte];
custom_options_format:CustomOptionsFormat;

// A list of indices to the subgraph's "tensors" that are internal to an Op.
// Internal tensors are those that do not flow in or out of the operation,
// but instead are part of internal computation. As such, the operation's
// implementation may manage its memory more efficiently. They are needed
// however (i.e. not just an implementation detail) since they are part of the
// computation, which may require relevant metadata such as quantization
// parameters.
intermediates:[int];

// Flatbuffers union struct has a 128 elements limit in JAVA, so a second
// union is added, in the case of where BuitlinOptions2 runs out, a third
// one can be added
builtin_options_2 : BuiltinOptions2;

// Index into operators_debug_metadata list.
debug_metadata_index: int = -1;
```

builtin_options and builtin_options_2 must interoperate correctly.
```python
tensors: List[Tensor] = field(default_factory=list)
operators: List[Operator] = field(default_factory=list)
inputs: List[Tensor] = field(default_factory=list)
outputs: List[Tensor] = field(default_factory=list)
name: Optional[str] = None
```
debug_metadata_index should be maintained.
```python
subgraphs: List[Subgraph] = field(default_factory=list)
buffers: _BufferList = field(
    default_factory=_BufferList)  # Auto-sets buffer.index on append
operator_codes: List[OperatorCode] = field(default_factory=list)
metadata: dict = field(default_factory=dict)
description: Optional[str] = None
```
Must maintain the following fields: version and signature_defs.
This is a draft PR for running CI tests on the full change. The commits
along this branch will be individually submitted for review.
See the linked issue for a description of the change.
BUG=implements #3256