Make the expectation functions traceable#4
MPMPMPMPMPMPMP wants to merge 1 commit into variPEPS:main
Yes, it's true that this is the reason we cannot trace over the expectation functions, but why not just use
I also tried that, but
(taken from https://openxla.org/xla/operation_semantics#conditional; this is wrapped in `jax.lax.cond`)
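The passage quoted from the XLA documentation is not reproduced above. One constraint of `jax.lax.cond` that matters in this context is that both branches must return values of the same shape and dtype, so a conditional cannot return a real array in one branch and a complex array in the other. The following is a minimal sketch of that constraint; `maybe_real` and `maybe_real_same_dtype` are hypothetical helpers for illustration and are not part of variPEPS.

```python
import jax
import jax.numpy as jnp


def maybe_real(value, is_hermitian):
    # Hypothetical helper: return the real part only if the flag is set.
    # The branches produce different dtypes (float32 vs. complex64),
    # which jax.lax.cond rejects while tracing both branches.
    return jax.lax.cond(is_hermitian, lambda v: jnp.real(v), lambda v: v, value)


value = jnp.asarray(1.0 + 0.5j, dtype=jnp.complex64)

try:
    maybe_real(value, jnp.asarray(True))
    cond_error = None
except TypeError as err:
    # Raised because the branch outputs do not have identical types.
    cond_error = err


def maybe_real_same_dtype(value, is_hermitian):
    # Both branches return complex64, so the conditional can be staged out.
    # The imaginary part is zeroed instead of dropped.
    return jax.lax.cond(
        is_hermitian,
        lambda v: jnp.real(v).astype(v.dtype),
        lambda v: v,
        value,
    )


result = jax.jit(maybe_real_same_dtype)(value, jnp.asarray(True))
```

The second variant traces, but it keeps the complex dtype in both cases, so it does not give a real-valued result array for Hermitian gates.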
2aeb4e4 to 723cad9
TL;DR
The idea of this PR is to collect all the changes that are necessary to make the expectation functions traceable. Then we can jit part of the `__call__` method of the expectation function, which would reduce memory usage, since JAX also optimizes the computation under jit. It would also allow us to checkpoint the computation of the expectation value, which is worthwhile because computing the RDM is very memory intensive. I know there is also `checkpoint_ncon`, but this approach is more flexible.
The problematic part was this, which is included in all expectation functions:
This construct performs a Python-level loop and uses JAX arrays in a dynamic control context, which prevents JAX from tracing or staging out the function properly.
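The failure mode can be reproduced in isolation. The sketch below is a hypothetical stand-in and not the actual variPEPS code: it assumes a Hermiticity check written as a Python-level loop over the gates, with the function name, the `real_result` flag, and the toy 2x2 inputs chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def expectation_with_python_check(gates, rdm):
    # Assumed pattern: a Python-level loop that turns JAX arrays into
    # Python booleans. Under jit the arrays are tracers without concrete
    # values, so the implicit bool() conversion inside all() fails.
    real_result = all(jnp.allclose(g, g.T.conj()) for g in gates)
    values = [jnp.trace(g @ rdm) for g in gates]
    if real_result:
        values = [jnp.real(v) for v in values]
    return values


gates = [jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)]
rdm = jnp.eye(2, dtype=jnp.complex64) / 2

# Outside jit the check works because the arrays are concrete.
eager = expectation_with_python_check(gates, rdm)

try:
    jax.jit(expectation_with_python_check)(gates, rdm)
    trace_error = None
except TypeError as err:
    # JAX reports the failed boolean conversion as a TypeError subclass.
    trace_error = err
```

The output dtype of this function depends on the values of the gates, which is exactly the kind of data-dependent decision that cannot be made at trace time.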
Despite multiple workarounds, none of them integrated cleanly with JAX’s tracing model or yielded good memory behavior.
My proposal is to remove all of these and let users handle the dtype in the model they define. Since the result array is small (its size is on the order of the number of gates), memory should not be a problem.
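A minimal sketch of what this proposal could look like, under stated assumptions: `expectation_values` and `energy` are hypothetical names, and the single `einsum` stands in for the real RDM contraction. The expectation function returns whatever dtype the contraction yields, and the model decides whether to take the real part.

```python
import jax
import jax.numpy as jnp


@jax.jit
@jax.checkpoint
def expectation_values(gates, rdm):
    # Hypothetical traceable expectation function: no Hermiticity check,
    # so the result keeps the dtype of the contraction (complex here).
    # jax.checkpoint only has an effect when the function is differentiated.
    return jnp.einsum("gij,ji->g", gates, rdm)  # Tr(gate @ rdm) per gate


def energy(gates, rdm):
    # The model knows that its gates are Hermitian and takes the real part.
    return jnp.sum(jnp.real(expectation_values(gates, rdm)))


sigma_x = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
sigma_z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
gates = jnp.stack([sigma_x, sigma_z])
rdm = jnp.array([[0.75, 0.25], [0.25, 0.25]], dtype=jnp.complex64)

values = expectation_values(gates, rdm)  # complex64, one entry per gate
total = energy(gates, rdm)               # real scalar
```

Because the output dtype no longer depends on the gate values, the function can be wrapped in `jax.jit` and `jax.checkpoint` without any special handling.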
An example from my code:
AI SUMMARY
This pull request simplifies the `three_sites.py` module by removing the `real_result` logic from the `_three_site_triangle_workhorse` function and its associated callers. This streamlines the computation of expectation values for three-site triangles, ensuring consistent output regardless of whether the gates are Hermitian.
Refactoring and code simplification:
- Removed the `real_result` argument from the `_three_site_triangle_workhorse` function signature and its usage in all calling functions, eliminating conditional logic based on gate Hermiticity.
- Removed the `real_result` variable (which checked for Hermitian gates) from all relevant `calc_three_sites_triangle_without_*_multiple_gates` functions.
- Updated calls to `_three_site_triangle_workhorse` to remove the `real_result` parameter, further simplifying the function interfaces.