FEAT implement batching for memory interface #1325
maifeeulasad wants to merge 6 commits into Azure:main from
Conversation
- add `_SQLITE_MAX_BIND_VARS` constant to handle SQLite limits
- implement batching in `get_scores()` for the `score_ids` parameter
- implement batching in `get_message_pieces()` for `prompt_ids`, `original_values`, `converted_values`, `converted_value_sha256`
- implement batching in `get_attack_results()` for the `attack_result_ids` and `objective_sha256` parameters
- implement batching in `get_scenario_results()` for `scenario_result_ids`
- refactor necessary filter conditions across batched queries
- handle empty-list edge cases
- independent batching of all parameters
- extracted batched `IN` condition
- helper functions
…re#845): tests focusing on independent batching of all parameters
Pull request overview
This PR aims to address scaling limits in the PyRIT `MemoryInterface` by introducing “batched” filtering logic for large `IN (...)` queries, along with unit tests intended to validate behavior when many IDs/values are passed.
Changes:
- Added `_SQLITE_MAX_BIND_VARS` and a `_batched_in_condition` helper in `MemoryInterface` and applied it to several query filters.
- Updated multiple retrieval methods (e.g., scores, message pieces, attack results, scenario results) to use the new helper for `IN` filtering.
- Added a new unit test module to exercise “batching” behavior and large-input queries against the SQLite-backed memory instance.
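For orientation, the batched helper seen in the diffs below could plausibly look like this minimal sketch, reconstructed only from its call sites (the method body, the chunking strategy, and the `_query_entries` interaction are assumptions, not the PR's actual code):

```python
from typing import Any, List, Optional, Sequence

from sqlalchemy import and_

_SQLITE_MAX_BIND_VARS = 500  # the constant introduced by this PR


def _execute_batched_query(
    self,
    entry_type: Any,
    *,
    batch_column: Any,
    batch_values: Sequence[Any],
    other_conditions: Optional[List[Any]] = None,
) -> List[Any]:
    """Run one query per chunk of batch_values and concatenate the results,
    keeping each IN (...) clause under the bind-variable limit."""
    entries: List[Any] = []
    for start in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS):
        chunk = list(batch_values[start : start + _SQLITE_MAX_BIND_VARS])
        conditions = list(other_conditions or []) + [batch_column.in_(chunk)]
        entries.extend(self._query_entries(entry_type, conditions=and_(*conditions)))
    return entries
```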
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `pyrit/memory/memory_interface.py` | Introduces a helper intended to avoid bind-variable limits and applies it across several query paths. |
| `tests/unit/memory/memory_interface/test_batching_scale.py` | Adds tests intended to validate that large ID/value inputs don’t break memory queries. |
```python
# ref: https://www.sqlite.org/limits.html
# Lowest default maximum is 999, intentionally setting it to half
_SQLITE_MAX_BIND_VARS = 500
```
`_SQLITE_MAX_BIND_VARS` is defined in the base `MemoryInterface`, but it’s SQLite-specific and is now used to batch queries for Azure SQL as well. This can be confusing and makes it harder to tune batching per backend; consider making the limit a backend-specific attribute (e.g., on `SQLiteMemory` / `AzureSQLMemory`) or deriving it from the SQLAlchemy dialect, with a safe default fallback.
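One possible shape for that suggestion, as a hypothetical sketch (the `engine` attribute, the per-dialect values, and the method name are assumptions; SQL Server's documented cap is 2100 parameters per statement):

```python
_DEFAULT_MAX_BIND_VARS = 500  # conservative fallback for unknown dialects


def _max_bind_vars(self) -> int:
    # Assumes the backend exposes its SQLAlchemy Engine as self.engine
    dialect_limits = {
        "sqlite": 500,  # half of SQLite's lowest default limit of 999
        "mssql": 2000,  # stays under SQL Server's 2100-parameter cap
    }
    return dialect_limits.get(self.engine.dialect.name, _DEFAULT_MAX_BIND_VARS)
```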
```python
# Handle score_ids with batched queries if needed
if score_ids:
    entries = self._execute_batched_query(
        ScoreEntry,
        batch_column=ScoreEntry.id,
        batch_values=list(score_ids),
        other_conditions=conditions,
    )
    return [entry.get_score() for entry in entries]

# No score_ids specified - use regular query
if not conditions:
    return []
```
`get_scores` treats `score_ids=[]` the same as “no `score_ids` filter” (because the check is `if score_ids:`). If callers pass an explicit empty list alongside other filters, the expected intersection is typically empty; consider mirroring the `get_attack_results`/`get_scenario_results` pattern by returning `[]` when `score_ids is not None and len(score_ids) == 0`.
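Concretely, the suggested guard could sit at the top of `get_scores` (a sketch of the reviewer's suggestion, not code from the PR):

```python
# An explicit empty list should mean "match nothing", not "no filter"
if score_ids is not None and len(score_ids) == 0:
    return []
```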
```python
# Identify list parameters and whether they need batching
list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = []
if attack_result_ids:
    list_params.append((AttackResultEntry.id, list(attack_result_ids), "id"))
if objective_sha256:
    list_params.append((AttackResultEntry.objective_sha256, list(objective_sha256), "objective_sha256"))

# If no list params, execute simple query
if not list_params:
    entries: Sequence[AttackResultEntry] = self._query_entries(
        AttackResultEntry, conditions=and_(*conditions) if conditions else None
    )
    return [entry.get_attack_result() for entry in entries]

# Find which list params need batching
large_params = [(col, vals, name) for col, vals, name in list_params if len(vals) > _SQLITE_MAX_BIND_VARS]
small_params = [(col, vals, name) for col, vals, name in list_params if len(vals) <= _SQLITE_MAX_BIND_VARS]

# Add small list params to conditions
for col, vals, _ in small_params:
    conditions.append(col.in_(vals))

# If no large params, execute simple query
if not large_params:
    entries = self._query_entries(AttackResultEntry, conditions=and_(*conditions) if conditions else None)
    return [entry.get_attack_result() for entry in entries]

# Batch on the first large parameter
batch_col, batch_vals, _ = large_params[0]
other_large_params = large_params[1:]

# Execute batched query
entries = self._execute_batched_query(
    AttackResultEntry,
    batch_column=batch_col,
    batch_values=batch_vals,
    other_conditions=conditions,
)
```
Batching logic was added to `get_attack_results` (including multi-list-parameter handling and Python-side filtering for additional large params), but there are no unit tests exercising the `> _SQLITE_MAX_BIND_VARS` path for `attack_result_ids` and/or `objective_sha256`. Adding a scale/batching test similar to `test_batching_scale.py` would help prevent regressions.
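A regression test along these lines would cover that path (a sketch only: the `sqlite_instance` fixture, the `make_attack_result` helper, and the `add_attack_results_to_memory` method name are assumptions meant to mirror the existing test suite):

```python
def test_get_attack_results_batching_scale(sqlite_instance):
    # Exceed the bind-variable limit so the batched code path runs
    n = _SQLITE_MAX_BIND_VARS + 10
    results = [make_attack_result(index=i) for i in range(n)]  # helper assumed
    sqlite_instance.add_attack_results_to_memory(attack_results=results)

    fetched = sqlite_instance.get_attack_results(
        attack_result_ids=[str(r.id) for r in results]
    )
    assert len(fetched) == n
```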
```diff
@@ -1483,13 +1646,19 @@ def get_scenario_results(
     for conv_ids in conversation_ids_by_attack.values():
         all_conversation_ids.extend(conv_ids)

-    # Query all AttackResults in a single batch if there are any
+    # Query all AttackResults using batched queries if needed
     if all_conversation_ids:
-        # Build condition to query multiple conversation IDs at once
-        attack_conditions = [AttackResultEntry.conversation_id.in_(all_conversation_ids)]
-        attack_entries: Sequence[AttackResultEntry] = self._query_entries(
-            AttackResultEntry, conditions=and_(*attack_conditions)
-        )
+        if len(all_conversation_ids) > _SQLITE_MAX_BIND_VARS:
+            attack_entries = self._execute_batched_query(
+                AttackResultEntry,
+                batch_column=AttackResultEntry.conversation_id,
+                batch_values=all_conversation_ids,
+            )
+        else:
+            attack_entries = self._query_entries(
+                AttackResultEntry,
+                conditions=AttackResultEntry.conversation_id.in_(all_conversation_ids),
+            )
```
`get_scenario_results` now batches when `scenario_result_ids` or the derived `all_conversation_ids` exceed `_SQLITE_MAX_BIND_VARS`, but there are no unit tests covering these batched code paths. Consider adding tests that create enough ScenarioResults / AttackResults to exceed the limit and verify results are complete and correctly associated back to each atomic attack name.
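Something like the following would exercise the batched branch (hypothetical: the `sqlite_instance` fixture, the `seed_scenario_results` helper, and the `attack_results` field are assumed names):

```python
def test_get_scenario_results_batching_scale(sqlite_instance):
    n = _SQLITE_MAX_BIND_VARS + 10
    scenario_ids = seed_scenario_results(sqlite_instance, count=n)  # helper assumed

    fetched = sqlite_instance.get_scenario_results(scenario_result_ids=scenario_ids)

    assert len(fetched) == n
    # Each result should still carry its associated attack results
    assert all(r.attack_results is not None for r in fetched)  # field name assumed
```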
Description
closes #845
Tests and Documentation
Ran the new test module first, then the larger test suite.
Documented in code where needed. Since batching stays quite under the hood, I'm not sure it requires changes to any other documentation, but I'm happy to work on the documentation if I missed something.
Open for upcoming reviews!