FEAT implement batching for memory interface#1325

Open
maifeeulasad wants to merge 6 commits into Azure:main from maifeeulasad:fix/memory-interface-batching-scale-maifee

Conversation

@maifeeulasad

Description

closes #845

Tests and Documentation

Ran the tests with

python -m pytest tests/unit/memory/memory_interface/test_batching_scale.py -v

Then ran the larger test suite with

python -m pytest tests/unit/**/**/*.py -v

Documented in code where needed; since batching stays under the hood, I'm not sure whether any user-facing documentation needs updating, but I'm happy to work on the documentation if I missed something.

Open for upcoming reviews!

 - add `_SQLITE_MAX_BIND_VARS` constant to handle SQLite limits
 - implement batching in `get_scores()` for `score_ids` parameter
 - implement batching in `get_message_pieces()` for `prompt_ids`, `original_values`, `converted_values`, `converted_value_sha256`
 - implement batching in `get_attack_results()` for `attack_result_ids` and `objective_sha256` parameters
 - implement batching in `get_scenario_results()` for `scenario_result_ids`
 - refactor necessary filter conditions across batched queries
 - handle empty list edge cases

```
python -m pytest tests/unit/memory/memory_interface/test_batching_scale.py -v
```
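The batching strategy listed above can be sketched in isolation. This is a minimal illustration only: the helper name `batched_in_query` is hypothetical and merely mirrors the idea behind the PR's `_execute_batched_query`.

```python
from typing import Any, Callable, Sequence

# SQLite's lowest default bind-variable limit is 999; use half as a
# safety margin, mirroring the PR's _SQLITE_MAX_BIND_VARS constant
_SQLITE_MAX_BIND_VARS = 500


def batched_in_query(
    run_query: Callable[[Sequence[Any]], Sequence[Any]],
    values: Sequence[Any],
) -> list[Any]:
    """Split `values` into chunks below the bind-variable limit, run the
    query once per chunk, and concatenate the results. `run_query` stands
    in for a SQLAlchemy query that applies `column.in_(chunk)`."""
    if not values:  # empty-list edge case: skip querying entirely
        return []
    results: list[Any] = []
    for start in range(0, len(values), _SQLITE_MAX_BIND_VARS):
        results.extend(run_query(values[start:start + _SQLITE_MAX_BIND_VARS]))
    return results
```

With 1,201 IDs this issues three queries of 500, 500, and 201 bound variables instead of one oversized `IN (...)` clause.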
@romanlutz romanlutz changed the title implement batching for memory interface FEAT implement batching for memory interface Jan 25, 2026
@maifeeulasad maifeeulasad marked this pull request as draft January 26, 2026 07:08
 - independent batching of all parameters
 - extracted batch in condition
 - helper functions
…re#845)

 - tests focusing on independent batching of all parameters
@maifeeulasad maifeeulasad marked this pull request as ready for review January 26, 2026 11:26
@romanlutz romanlutz requested a review from Copilot February 6, 2026 21:39
Contributor

Copilot AI left a comment


Pull request overview

This PR aims to address scaling limits in the PyRIT MemoryInterface by introducing “batched” filtering logic for large IN (...) queries, along with unit tests intended to validate behavior when many IDs/values are passed.

Changes:

  • Added _SQLITE_MAX_BIND_VARS and _batched_in_condition helper in MemoryInterface and applied it to several query filters.
  • Updated multiple retrieval methods (e.g., scores, message pieces, attack results, scenario results) to use the new helper for IN filtering.
  • Added a new unit test module to exercise “batching” behavior and large-input queries against the SQLite-backed memory instance.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File Description
pyrit/memory/memory_interface.py Introduces a helper intended to avoid bind-variable limits and applies it across several query paths.
tests/unit/memory/memory_interface/test_batching_scale.py Adds tests intended to validate that large ID/value inputs don’t break memory queries.

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

Comment on lines +53 to +55
# ref: https://www.sqlite.org/limits.html
# Lowest default maximum is 999, intentionally setting it to half
_SQLITE_MAX_BIND_VARS = 500

Copilot AI Feb 7, 2026


_SQLITE_MAX_BIND_VARS is defined in the base MemoryInterface, but it’s SQLite-specific and is now used to batch queries for Azure SQL as well. This can be confusing and makes it harder to tune batching per-backend; consider making the limit a backend-specific attribute (e.g., on SQLiteMemory / AzureSQLMemory) or deriving it from the SQLAlchemy dialect, with a safe default fallback.
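One way to realize that suggestion is a class attribute overridden per backend. A sketch only: the `_max_bind_vars` attribute is hypothetical, and the class names merely echo PyRIT's.

```python
class MemoryInterface:
    # Conservative fallback when no backend-specific limit is known
    _max_bind_vars: int = 500


class SQLiteMemory(MemoryInterface):
    # SQLite's lowest default SQLITE_MAX_VARIABLE_NUMBER is 999
    _max_bind_vars = 999


class AzureSQLMemory(MemoryInterface):
    # SQL Server caps a request at roughly 2100 parameters
    _max_bind_vars = 2000
```

Batching code would then read `self._max_bind_vars` instead of a module-level SQLite-specific constant.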

Comment on lines +445 to 457
# Handle score_ids with batched queries if needed
if score_ids:
entries = self._execute_batched_query(
ScoreEntry,
batch_column=ScoreEntry.id,
batch_values=list(score_ids),
other_conditions=conditions,
)
return [entry.get_score() for entry in entries]

# No score_ids specified - use regular query
if not conditions:
return []

Copilot AI Feb 7, 2026


get_scores treats score_ids=[] the same as “no score_ids filter” (because the check is if score_ids:). If callers pass an explicit empty list alongside other filters, the expected intersection is typically empty; consider mirroring the get_attack_results/get_scenario_results pattern by returning [] when score_ids is not None and len(score_ids) == 0.
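The three cases the comment distinguishes can be made explicit with a `None` check. A sketch, not code from the PR; `score_ids_filter` is a hypothetical helper.

```python
from typing import Optional, Sequence


def score_ids_filter(score_ids: Optional[Sequence[str]]):
    """Distinguish 'no filter' (None), 'filter that matches nothing'
    (empty list), and 'filter on these IDs' (non-empty list)."""
    if score_ids is None:
        return None            # caller applies no score_ids condition
    if len(score_ids) == 0:
        return []              # short-circuit: the intersection is empty
    return list(score_ids)     # batch/filter on exactly these IDs
```

An `if score_ids:` check collapses the first two cases, which is the bug the comment describes.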

Comment on lines +1390 to 1427
# Identify list parameters and whether they need batching
list_params: list[tuple[InstrumentedAttribute[Any], Sequence[Any], str]] = []
if attack_result_ids:
list_params.append((AttackResultEntry.id, list(attack_result_ids), "id"))
if objective_sha256:
list_params.append((AttackResultEntry.objective_sha256, list(objective_sha256), "objective_sha256"))

# If no list params, execute simple query
if not list_params:
entries: Sequence[AttackResultEntry] = self._query_entries(
AttackResultEntry, conditions=and_(*conditions) if conditions else None
)
return [entry.get_attack_result() for entry in entries]

# Find which list params need batching
large_params = [(col, vals, name) for col, vals, name in list_params if len(vals) > _SQLITE_MAX_BIND_VARS]
small_params = [(col, vals, name) for col, vals, name in list_params if len(vals) <= _SQLITE_MAX_BIND_VARS]

# Add small list params to conditions
for col, vals, _ in small_params:
conditions.append(col.in_(vals))

# If no large params, execute simple query
if not large_params:
entries = self._query_entries(AttackResultEntry, conditions=and_(*conditions) if conditions else None)
return [entry.get_attack_result() for entry in entries]

# Batch on the first large parameter
batch_col, batch_vals, _ = large_params[0]
other_large_params = large_params[1:]

# Execute batched query
entries = self._execute_batched_query(
AttackResultEntry,
batch_column=batch_col,
batch_values=batch_vals,
other_conditions=conditions,
)

Copilot AI Feb 7, 2026


Batching logic was added to get_attack_results (including multi-list-parameter handling and Python-side filtering for additional large params), but there are no unit tests exercising the > _SQLITE_MAX_BIND_VARS path for attack_result_ids and/or objective_sha256. Adding a scale/batching test similar to test_batching_scale.py would help prevent regressions.
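A regression test along those lines might look roughly like this, shown against a stand-in for `_execute_batched_query` since the PR's fixtures and memory instance aren't reproduced here.

```python
_SQLITE_MAX_BIND_VARS = 500


def execute_batched_query(lookup, batch_values):
    """Stand-in for the PR's batched query: chunk the values under the
    bind-variable limit and union the per-chunk results."""
    results = []
    for start in range(0, len(batch_values), _SQLITE_MAX_BIND_VARS):
        results.extend(lookup(batch_values[start:start + _SQLITE_MAX_BIND_VARS]))
    return results


def test_attack_result_ids_above_bind_limit():
    # 1201 IDs force three batches (500 + 500 + 201)
    stored = {f"attack-{i}": f"result-{i}" for i in range(1201)}
    lookup = lambda chunk: [stored[k] for k in chunk if k in stored]
    out = execute_batched_query(lookup, list(stored))
    assert len(out) == 1201              # nothing lost at batch boundaries
    assert out == list(stored.values())  # input order preserved per batch
```

A real test would instead insert >500 `AttackResultEntry` rows into the SQLite-backed memory fixture and call `get_attack_results` with all their IDs.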

Comment on lines 1622 to 1661
@@ -1483,13 +1646,19 @@ def get_scenario_results(
         for conv_ids in conversation_ids_by_attack.values():
             all_conversation_ids.extend(conv_ids)

-        # Query all AttackResults in a single batch if there are any
+        # Query all AttackResults using batched queries if needed
         if all_conversation_ids:
-            # Build condition to query multiple conversation IDs at once
-            attack_conditions = [AttackResultEntry.conversation_id.in_(all_conversation_ids)]
-            attack_entries: Sequence[AttackResultEntry] = self._query_entries(
-                AttackResultEntry, conditions=and_(*attack_conditions)
-            )
+            if len(all_conversation_ids) > _SQLITE_MAX_BIND_VARS:
+                attack_entries = self._execute_batched_query(
+                    AttackResultEntry,
+                    batch_column=AttackResultEntry.conversation_id,
+                    batch_values=all_conversation_ids,
+                )
+            else:
+                attack_entries = self._query_entries(
+                    AttackResultEntry,
+                    conditions=AttackResultEntry.conversation_id.in_(all_conversation_ids),
+                )

Copilot AI Feb 7, 2026


get_scenario_results now batches when scenario_result_ids or the derived all_conversation_ids exceed _SQLITE_MAX_BIND_VARS, but there are no unit tests covering these batched code paths. Consider adding tests that create enough ScenarioResults / AttackResults to exceed the limit and verify results are complete and correctly associated back to each atomic attack name.



Development

Successfully merging this pull request may close these issues.

BUG get_scores_by_memory_labels and a few other memory methods do not scale
