Example of a ForceField implemented with JAX#557
Example of a ForceField implemented with JAX#557leobois67 wants to merge 7 commits intosofa-framework:masterfrom
Conversation
|
Note: it requires JAX to run |
|
Sorry, but I get a segfault while running it at the very first timestep. |
|
I believe this example stopped working after this commit 5029c050f68078ab5173d1ffcace69598caa7f8b by @alxbilger. I don’t know exactly how things worked before, but now when I try to fix it, (1) I have to add an extra dimension to enter the correct |
|
@leobois67 you're probably right about my commit. However, I don't understand how it could work before. I suspect that the matrix was not filled with your Python force field. This hypothesis is strenghened by the fact that it is slower now. Is it possible? |
|
@alxbilger I agree that it probably did not work as it was supposed to. The matrix that was passed was not a numpy array but a jax array, with shape (n, 3, n, 3), so I don’t know how it was processed, but I guess it was silently ignored. Also, it is the processing of the matrix that is slow, not its computation: returning a dense matrix full of zeros seems to be just as slow. Is there a way to improve that? |
|
In the code in 5029c05, each |
|
I think I got something that works relatively well, by returning only the non-zero values as you suggested. I left the code that returns the dense matrix for people who don’t have a big sparse jacobian. To give you an idea of the impact on the performances, here are some stats with the simulation of 1000 independent particles:
By "sparse jacobian" I mean returning only the non-zero values; by "optimized" I refer to an optimization I mention in the code, that leverages the knowledge of the sparsity of the jacobian; by "GPU"/"CPU" I refer to the device JAX uses. Also, most of the time spent in |
|
@leobois67 thanks for the benchmark. I don't know JAX enough to answer your question. I can tell that in C++/CUDA, we manipulate a raw pointer, whether the data are on the CPU or on the GPU. We would need a way to communicate the memory location of this data to JAX somehow. Don't know if that's possible |
|
Hello ! I still have a segfault in the buildStifnessMatrix. Could someone else than me try this ? @alxbilger |
|
I ran it on CPU and also have it @bakpaul ########## SIG 11 - SIGSEGV: segfault ##########
sofa::helper::BackTrace::sig(int)
sofa::core::behavior::BaseForceField::buildStiffnessMatrix(sofa::core::behavior::StiffnessMatrix*)
sofa::component::linearsystem::MatrixLinearSystem<sofa::linearalgebra::CompressedRowSparseMatrixMechanical<double, sofa::linearalgebra::CRSMechanicalPolicy>, sofa::linearalgebra::FullVector<double> >::assembleSystem(sofa::core::MechanicalParams const*)::{lambda(sofa::component::linearsystem::MatrixLinearSystem<sofa::linearalgebra::CompressedRowSparseMatrixMechanical<double, sofa::linearalgebra::CRSMechanicalPolicy>, sofa::linearalgebra::FullVector<double> >::IndependentContributors&)#1}::operator()(sofa::component::linearsystem::MatrixLinearSystem<sofa::linearalgebra::CompressedRowSparseMatrixMechanical<double, sofa::linearalgebra::CRSMechanicalPolicy>, sofa::linearalgebra::FullVector<double> >::IndependentContributors&) const
sofa::component::linearsystem::MatrixLinearSystem<sofa::linearalgebra::CompressedRowSparseMatrixMechanical<double, sofa::linearalgebra::CRSMechanicalPolicy>, sofa::linearalgebra::FullVector<double> >::assembleSystem(sofa::core::MechanicalParams const*)
sofa::core::behavior::BaseMatrixLinearSystem::buildSystemMatrix(sofa::core::MechanicalParams const*)
sofa::component::odesolver::backward::EulerImplicitSolver::solve(sofa::core::ExecParams const*, double, sofa::core::TMultiVecId<(sofa::core::VecType)1, (sofa::core::VecAccess)1>, sofa::core::TMultiVecId<(sofa::core::VecType)2, (sofa::core::VecAccess)1>)
sofa::simulation::SolveVisitor::processSolver(sofa::simulation::Node*, sofa::core::behavior::OdeSolver*)
void sofa::simulation::Visitor::for_each<sofa::simulation::SolveVisitor, sofa::simulation::Node, sofa::simulation::NodeSequence<sofa::core::behavior::OdeSolver, false>, sofa::core::behavior::OdeSolver>(sofa::simulation::SolveVisitor*, sofa::simulation::Node*, sofa::simulation::NodeSequence<sofa::core::behavior::OdeSolver, false> const&, void (sofa::simulation::SolveVisitor::*)(sofa::simulation::Node*, sofa::core::behavior::OdeSolver*), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)
sofa::simulation::SolveVisitor::processNodeTopDown(sofa::simulation::Node*)
sofa::simulation::Node::executeVisitorTopDown(sofa::simulation::Visitor*, std::__cxx11::list<sofa::simulation::Node*, std::allocator<sofa::simulation::Node*> >&, std::map<sofa::simulation::Node*, sofa::simulation::Node::StatusStruct, std::less<sofa::simulation::Node*>, std::allocator<std::pair<sofa::simulation::Node* const, sofa::simulation::Node::StatusStruct> > >&, sofa::simulation::Node*)
sofa::simulation::Node::executeVisitorTopDown(sofa::simulation::Visitor*, std::__cxx11::list<sofa::simulation::Node*, std::allocator<sofa::simulation::Node*> >&, std::map<sofa::simulation::Node*, sofa::simulation::Node::StatusStruct, std::less<sofa::simulation::Node*>, std::allocator<std::pair<sofa::simulation::Node* const, sofa::simulation::Node::StatusStruct> > >&, sofa::simulation::Node*)
sofa::simulation::Node::doExecuteVisitor(sofa::simulation::Visitor*, bool)
sofa::simulation::DefaultAnimationLoop::solve(sofa::core::ExecParams const*, double) const
sofa::simulation::DefaultAnimationLoop::animate(sofa::core::ExecParams const*, double) const
sofa::simulation::DefaultAnimationLoop::step(sofa::core::ExecParams const*, double)
sofa::simulation::node::animate(sofa::simulation::Node*, double)
sofaglfw::SofaGLFWBaseGUI::runLoop(unsigned long)
sofaglfw::SofaGLFWGUI::mainLoop()
sofa::gui::common::GUIManager::MainLoop(boost::intrusive_ptr<sofa::simulation::Node>, char const*)
__libc_start_main@leobois67 could you please update us on the status on your machine? |
|
|
||
|
|
||
| # Some configuration for JAX: device and precision | ||
| jax.config.update("jax_default_device", jax.devices("gpu")[0]) # default "gpu" |
There was a problem hiding this comment.
Possibly add a comment for a CPU version :
| jax.config.update("jax_default_device", jax.devices("gpu")[0]) # default "gpu" | |
| # jax.config.update('jax_default_device', jax.devices('cpu')[0]) | |
| jax.config.update("jax_default_device", jax.devices("gpu")[0]) # default "gpu" |
|
I just checked and it still works on my machine, both on CPU and GPU. To be sure that |
|
You can come to me next time you work on this PR, but in case I am not available, here are some other quick suggestions:
|
|
Is that what we are supposed to see? @leobois67 Screencast.from.2026-01-29.18-59-53.webm |
|
No it's not! Each particle is attached to the origin with a spring, so they should be kind of oscillating like a pendulum + oscillating due to the spring expanding and contracting; which is not what I see in the video. Maybe try different options to check whether you get the same behavior or not? |
Illustrates how to leverage JAX autodiff to implement the methods 'addDForce()' and 'addKToMatrix()' automatically.
This example shows a set of particles attached to the origin with a simple spring.
I included a few options for the scene, that might be unnecessary: