Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 103 additions & 30 deletions doc/extensions/ecproofs/ecproofs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# --------------------------------------------------------------
from __future__ import annotations
from typing import Any

import docutils as du

Expand All @@ -14,9 +15,13 @@
import subprocess as subp
import tempfile


# ======================================================================
ROOT = os.path.dirname(__file__)

# ======================================================================
logger = su.logging.getLogger(__name__)

# ======================================================================
class ProofnavNode(du.nodes.General, du.nodes.Element):
@staticmethod
Expand All @@ -37,58 +42,95 @@ def depart_proofnav_node_html(self, node: ProofnavNode):

self.body.append(html)

# ======================================================================
class EasyCryptError(se.SphinxError):
category = "easycrypt"

# ======================================================================
class EasyCrypt:
@staticmethod
def run(cmd, *, location: Any | None = None, warn_only: bool = True):
try:
proc = subp.run(
cmd, check = True, text = True, capture_output = True,
)
logger.debug("Command stdout:\n%s", proc.stdout)
logger.debug("Command stderr:\n%s", proc.stderr)

return True

except subp.CalledProcessError as e:
msg = f"{cmd[0]} exited with code {e.returncode}"

if e.stdout:
logger.debug("stdout:\n%s", e.stdout, location = location)
if e.stderr:
logger.debug("stderr:\n%s", e.stderr, location = location)

logs = [x.split(maxsplit = 1) for x in e.stderr.splitlines()]
logs = [x[1] for x in logs if len(x) == 2 and x[0] == 'E']

for log in logs:
logger.error(log, location = location, type = EasyCryptError.category)

logger.error(msg, location = location, type = EasyCryptError.category)

raise EasyCryptError(msg) from e

# ======================================================================
class EasyCryptProofDirective(su.docutils.SphinxDirective):
TRAP_RE = r'\(\*\s*\$\s*\*\)\s*'

has_content = True

option_spec = {
'title': su.docutils.directives.unchanged,
}

def run(self):
env = self.state.document.settings.env
def find_trap(self, source: str):
location = (self.state.document.current_source, self.lineno)

rawcode = '\n'.join(self.content) + '\n'
if (trap := re.search(self.TRAP_RE, source, re.MULTILINE)) is None:
logger.error(
'Cannot find the trap',
location = location, type = EasyCryptError.category)
raise EasyCryptError

# Find the trap
if (trap := re.search(r'\(\*\s*\$\s*\*\)\s*', rawcode, re.MULTILINE)) is None:
raise se.SphinxError('Cannot find the trap')
code = rawcode[:trap.start()] + rawcode[trap.end():]

# Find the trap sentence number
sentences = [
m.end() - 1
for m in re.finditer(r'\.(\s+|\$)', code)
]
sentence = bisect.bisect_left(sentences, trap.start())
return trap

def run_easycrypt(self, source: str):
location = (self.state.document.current_source, self.lineno)

# Run EasyCrypt and extract the proof trace
with tempfile.TemporaryDirectory(delete = False) as tmpdir:
ecfile = os.path.join(tmpdir, 'input.ec')
ecofile = os.path.join(tmpdir, 'input.eco')

with open(ecfile, 'w') as ecstream:
ecstream.write(code)
subp.check_call(
['easycrypt', 'compile', '-pragmas', 'Proofs:weak', '-trace', ecfile],
stdout = subp.DEVNULL,
stderr = subp.DEVNULL,
ecstream.write(source)

EasyCrypt.run(
['easycrypt', 'compile', '-script', '-pragmas', 'Proofs:weak', '-trace', ecfile],
location = location
)

with open(ecofile) as ecostream:
eco = json.load(ecostream)
return json.load(ecostream)

serial = env.new_serialno("proofnav")
def create_widget(self, code: str, sentence: int, eco: Any):
serial = self.state.document.settings.env.new_serialno("proofnav")
uid = f"proofnav-{serial}"

# Create widget metadata
data = dict()

data["source"] = code
data["sentenceEnds"] = [x["position"] for x in eco["trace"][1:]]
data["sentences"] = [
sentences = [
dict(goals = x["goals"], message = x["messages"])
for x in eco["trace"][1:]
]
data["initialSentence"] = sentence - 1

data = dict(
source = code,
sentenceEnds = [x["position"] for x in eco["trace"][1:]],
sentences = sentences,
initialSentence = sentence - 1,
)

if 'title' in self.options:
data['title'] = self.options['title']
Expand All @@ -97,8 +139,39 @@ def run(self):
node["uid"] = uid
node["json"] = json.dumps(
data, ensure_ascii = False, separators = (",", ":"), indent = 2)

return node

return [node]
def run(self):
try:
rawcode = '\n'.join(self.content) + '\n'

# Find the trap and erase it
trap = self.find_trap(rawcode)
code = rawcode[:trap.start()] + rawcode[trap.end():]

# Find the trap sentence number
sentences = [
m.end() - 1
for m in re.finditer(r'\.(\s+|\$)', code)
]
sentence = bisect.bisect_left(sentences, trap.start())

# Run EasyCrypt and extract the proof trace
eco = self.run_easycrypt(code)

# Create the widget
node = self.create_widget(code, sentence, eco)
return [node]

except EasyCryptError:
self.state.document.settings.env.note_reread()

fallback = du.nodes.literal_block(
"[easycrypt failed]",
"[easycrypt failed]",
)
return [fallback]

# ======================================================================
def on_builder_inited(app: sa.Sphinx):
Expand Down