diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..8d421df --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,9 @@ +{ + "permissions": { + "allow": [ + "Bash(pytest:*)", + "Bash(python -m pytest:*)", + "Bash(python:*)" + ] + } +} diff --git a/config.json b/config.json index d13fa81..a5a450b 100644 --- a/config.json +++ b/config.json @@ -1,3 +1,3 @@ { - "EvaluationFunctionName": "" + "EvaluationFunctionName": "complexityEval" } diff --git a/evaluation_function/analyzer/__init__.py b/evaluation_function/analyzer/__init__.py new file mode 100644 index 0000000..48c06d9 --- /dev/null +++ b/evaluation_function/analyzer/__init__.py @@ -0,0 +1,15 @@ +""" +Complexity Analyzer module. + +This module provides analysis of pseudocode to determine time and space complexity. +""" + +from .complexity_analyzer import ComplexityAnalyzer, AnalysisResult +from .feedback_generator import FeedbackGenerator, DetailedFeedback + +__all__ = [ + "ComplexityAnalyzer", + "AnalysisResult", + "FeedbackGenerator", + "DetailedFeedback", +] diff --git a/evaluation_function/analyzer/complexity_analyzer.py b/evaluation_function/analyzer/complexity_analyzer.py new file mode 100644 index 0000000..1ebd9ce --- /dev/null +++ b/evaluation_function/analyzer/complexity_analyzer.py @@ -0,0 +1,713 @@ +""" +Complexity Analyzer - Analyzes pseudocode AST to determine time complexity. + +This module provides the core analysis logic for determining algorithm complexity +from parsed pseudocode, with detailed tracking of complexity factors. 
+""" + +from typing import List, Optional, Dict, Any, Tuple +from dataclasses import dataclass, field +import re +import math + +from ..schemas.complexity import ( + ComplexityClass, + LoopComplexity, + RecursionComplexity, + TimeComplexity, + SpaceComplexity, + ComplexityResult, + ComplexityFactor, +) +from ..schemas.ast_nodes import ( + ProgramNode, + FunctionNode, + BlockNode, + LoopNode, + ConditionalNode, + LoopType, +) + + +@dataclass +class LoopInfo: + """Information about a detected loop.""" + loop_type: str + iterator: Optional[str] + start_bound: Optional[str] + end_bound: Optional[str] + step: Optional[str] + iterations: str + complexity: ComplexityClass + nesting_level: int + line_number: int = 0 + has_early_exit: bool = False + nested_loops: List["LoopInfo"] = field(default_factory=list) + + def get_description(self) -> str: + """Get human-readable description of this loop.""" + if self.loop_type == "for" and self.iterator: + if self.start_bound and self.end_bound: + return f"FOR loop ({self.iterator} from {self.start_bound} to {self.end_bound})" + return f"FOR loop with iterator {self.iterator}" + elif self.loop_type == "for_each": + return f"FOR-EACH loop iterating over collection" + elif self.loop_type == "while": + return f"WHILE loop" + elif self.loop_type == "repeat": + return f"REPEAT-UNTIL loop" + return f"{self.loop_type.upper()} loop" + + +@dataclass +class RecursionInfo: + """Information about detected recursion.""" + function_name: str + num_recursive_calls: int + reduction_pattern: str # "n-1", "n/2", etc. 
+ branching_factor: int + work_per_call: ComplexityClass + complexity: ComplexityClass + recurrence: str # e.g., "T(n) = 2T(n/2) + O(n)" + + def get_description(self) -> str: + """Get human-readable description of this recursion.""" + if self.branching_factor == 1: + return f"Linear recursion in {self.function_name}()" + elif self.branching_factor == 2: + if "n/2" in self.reduction_pattern: + return f"Divide-and-conquer recursion in {self.function_name}()" + else: + return f"Binary recursion in {self.function_name}()" + return f"Multiple recursion ({self.branching_factor} calls) in {self.function_name}()" + + +@dataclass +class AnalysisResult: + """Complete result of complexity analysis.""" + time_complexity: ComplexityClass + space_complexity: ComplexityClass + loops: List[LoopInfo] + recursion: Optional[RecursionInfo] + max_nesting_depth: int + confidence: float + factors: List[ComplexityFactor] + raw_code: str = "" + + def get_complexity_string(self) -> str: + """Get the complexity as a string.""" + return self.time_complexity.value + + +class ComplexityAnalyzer: + """ + Analyzes pseudocode to determine time and space complexity. + + This analyzer works with both parsed AST and raw code fallback, + providing detailed analysis of loops, recursion, and other factors. + """ + + def __init__(self): + self.loops: List[LoopInfo] = [] + self.recursion_info: Optional[RecursionInfo] = None + self.factors: List[ComplexityFactor] = [] + self.current_function: Optional[str] = None + self.function_calls: Dict[str, List[str]] = {} + + def analyze(self, code: str, ast: Optional[ProgramNode] = None) -> AnalysisResult: + """ + Analyze code complexity. 
+ + Args: + code: The pseudocode string + ast: Optional parsed AST (uses pattern matching if not provided) + + Returns: + AnalysisResult with complexity information + """ + self._reset() + + if ast: + return self._analyze_ast(ast, code) + else: + return self._analyze_patterns(code) + + def _reset(self): + """Reset analyzer state.""" + self.loops = [] + self.recursion_info = None + self.factors = [] + self.current_function = None + self.function_calls = {} + + def _analyze_ast(self, ast: ProgramNode, code: str) -> AnalysisResult: + """Analyze from parsed AST.""" + # Analyze functions + for func in ast.functions: + self.current_function = func.name + self.function_calls[func.name] = [] + + if func.body: + self._analyze_block(func.body, nesting_level=0) + + # Check for recursion using pattern matching on code + # Look for recursive calls in the code + if code: + call_pattern = rf'\b{func.name}\s*\(' + calls = re.findall(call_pattern, code, re.IGNORECASE) + if len(calls) > 1: # More than just the definition + self._detect_recursion_in_function(func, code) + + # Analyze global statements + if ast.global_statements: + self._analyze_block(ast.global_statements, nesting_level=0) + + return self._compute_result(code) + + def _analyze_block(self, block: BlockNode, nesting_level: int): + """Analyze a block of statements (top-level, adds to self.loops).""" + for stmt in block.statements: + if isinstance(stmt, LoopNode): + loop_info = self._analyze_loop(stmt, nesting_level) + self.loops.append(loop_info) + elif isinstance(stmt, ConditionalNode): + # For top-level conditionals, add any loops they contain to self.loops + loops_from_cond = self._analyze_conditional(stmt, nesting_level) + self.loops.extend(loops_from_cond) + + def _analyze_loop(self, loop: LoopNode, nesting_level: int) -> LoopInfo: + """Analyze a loop node.""" + loop_type = loop.loop_type.value if loop.loop_type else "for" + iterator = loop.iterator.name if loop.iterator else None + + # Determine iterations + 
iterations, complexity = self._estimate_loop_iterations(loop) + + # Get bounds + start_bound = self._expr_to_string(loop.start) + end_bound = self._expr_to_string(loop.end) + step = self._expr_to_string(loop.step) if loop.step else "1" + + loop_info = LoopInfo( + loop_type=loop_type, + iterator=iterator, + start_bound=start_bound, + end_bound=end_bound, + step=step, + iterations=iterations, + complexity=complexity, + nesting_level=nesting_level, + line_number=loop.location.line if loop.location else 0 + ) + + # Analyze nested content + if loop.body: + for stmt in loop.body.statements: + if isinstance(stmt, LoopNode): + nested_info = self._analyze_loop(stmt, nesting_level + 1) + loop_info.nested_loops.append(nested_info) + elif isinstance(stmt, ConditionalNode): + # Loops inside conditionals should be counted as nested + nested_from_cond = self._analyze_conditional(stmt, nesting_level + 1) + loop_info.nested_loops.extend(nested_from_cond) + + return loop_info + + def _analyze_conditional(self, cond: ConditionalNode, nesting_level: int) -> List[LoopInfo]: + """Analyze conditional branches for loops. Returns loops found for nesting.""" + nested_loops = [] + if cond.then_branch: + nested_loops.extend(self._extract_loops_from_block(cond.then_branch, nesting_level)) + if cond.else_branch: + nested_loops.extend(self._extract_loops_from_block(cond.else_branch, nesting_level)) + return nested_loops + + def _extract_loops_from_block(self, block: BlockNode, nesting_level: int) -> List[LoopInfo]: + """Extract loops from a block without adding to self.loops. 
Used for nested contexts.""" + loops = [] + for stmt in block.statements: + if isinstance(stmt, LoopNode): + loop_info = self._analyze_loop(stmt, nesting_level) + loops.append(loop_info) + elif isinstance(stmt, ConditionalNode): + loops.extend(self._analyze_conditional(stmt, nesting_level)) + return loops + + def _estimate_loop_iterations(self, loop: LoopNode) -> Tuple[str, ComplexityClass]: + """Estimate the number of iterations for a loop.""" + if loop.estimated_iterations: + iterations = loop.estimated_iterations + elif loop.loop_type == LoopType.FOR: + iterations = self._estimate_for_iterations(loop) + elif loop.loop_type == LoopType.FOR_EACH: + iterations = "n" # Collection size + else: + iterations = "n" # Default for while/repeat + + # Convert iterations to complexity class + complexity = self._iterations_to_complexity(iterations) + + return iterations, complexity + + def _estimate_for_iterations(self, loop: LoopNode) -> str: + """Estimate FOR loop iterations from bounds.""" + start = self._expr_to_string(loop.start) + end = self._expr_to_string(loop.end) + + if not start or not end: + return "n" + + start_lower = start.lower() + end_lower = end.lower() + + # Common patterns + if start_lower in ("0", "1"): + if end_lower in ("n", "len", "length", "size", "count"): + return "n" + if "n-" in end_lower or "n -" in end_lower: + return "n" + if "n/2" in end_lower or "n / 2" in end_lower: + return "n/2" + if "log" in end_lower: + return "log(n)" + if end_lower.startswith("sqrt") or "√" in end_lower: + return "√n" + + # Check if both are constants + try: + s = int(start) + e = int(end) + return str(e - s + 1) + except (ValueError, TypeError): + pass + + # Default to n + return "n" + + def _iterations_to_complexity(self, iterations: str) -> ComplexityClass: + """Convert iteration count to complexity class.""" + iterations_lower = iterations.lower().replace(" ", "") + + if iterations_lower in ("1", "2", "3", "4", "5", "10", "100"): + return ComplexityClass.CONSTANT + 
if "log" in iterations_lower: + return ComplexityClass.LOGARITHMIC + if "sqrt" in iterations_lower or "√" in iterations_lower: + return ComplexityClass.SQRT + if iterations_lower in ("n", "n-1", "n+1", "len", "length", "size"): + return ComplexityClass.LINEAR + if "n/2" in iterations_lower: + return ComplexityClass.LINEAR # Still O(n) + + return ComplexityClass.LINEAR # Default + + def _expr_to_string(self, expr) -> Optional[str]: + """Convert expression node to string.""" + if expr is None: + return None + if hasattr(expr, 'name'): + return expr.name + if hasattr(expr, 'value'): + return str(expr.value) + return str(expr) + + def _detect_recursion_in_function(self, func: FunctionNode, code: str): + """Detect recursion pattern in a function.""" + func_name = func.name + + # Count recursive calls in code + call_pattern = rf'\b{func_name}\s*\(' + calls = re.findall(call_pattern, code, re.IGNORECASE) + num_calls = len(calls) - 1 # Subtract definition + + if num_calls <= 0: + return + + # Analyze reduction pattern + reduction, branching = self._analyze_recursive_calls(func_name, code) + + # Determine complexity + complexity = self._compute_recursion_complexity(branching, reduction) + + # Build recurrence relation + recurrence = self._build_recurrence(branching, reduction) + + self.recursion_info = RecursionInfo( + function_name=func_name, + num_recursive_calls=num_calls, + reduction_pattern=reduction, + branching_factor=branching, + work_per_call=ComplexityClass.CONSTANT, + complexity=complexity, + recurrence=recurrence + ) + + def _analyze_recursive_calls(self, func_name: str, code: str) -> Tuple[str, int]: + """Analyze recursive call patterns.""" + code_lower = code.lower() + + # Count recursive calls, distinguishing between: + # - Calls in mutually exclusive branches (if-else with RETURN) -> branching = 1 + # - Calls that all execute (sequential, or in same expression) -> branching = count + lines = code.split('\n') + total_recursive_calls = 0 + max_calls_per_line = 
0 + calls_with_return = 0 + calls_without_return = 0 + + for line in lines: + line_lower = line.lower().strip() + # Skip function definition line + if re.search(rf'(function|def|procedure|algorithm)\s+{func_name}\s*\(', line, re.IGNORECASE): + continue + calls = len(re.findall(rf'\b{func_name}\s*\(', line, re.IGNORECASE)) + if calls > 0: + total_recursive_calls += calls + max_calls_per_line = max(max_calls_per_line, calls) + # Check if this line has RETURN before the call + if re.search(r'\breturn\b', line, re.IGNORECASE): + calls_with_return += calls + else: + calls_without_return += calls + + # Determine branching factor: + # - If all calls are in return statements (like binary search), likely mutually exclusive -> branching = 1 + # - If calls are in same expression (like fibonacci), use max per line + # - If calls are sequential without return (like merge sort), all execute -> use total + if max_calls_per_line >= 2: + # Multiple calls on same line (e.g., fib(n-1) + fib(n-2)) + branching = max_calls_per_line + elif calls_without_return >= 2: + # Multiple sequential calls without return (merge sort pattern) + branching = calls_without_return + elif calls_with_return > 0 and calls_without_return == 0: + # All calls are in return statements (binary search pattern - mutually exclusive) + branching = 1 + else: + branching = max(1, total_recursive_calls) + + # Detect reduction pattern + # Use word boundaries to avoid matching parts of other words (e.g., "return -1") + patterns = [ + (r'\bn\s*-\s*1\b', 'n-1'), + (r'\bn\s*-\s*2\b', 'n-1'), # Still linear reduction + (r'\bn\s*/\s*2\b', 'n/2'), + (r'\bn\s*//\s*2\b', 'n/2'), + (r'\bmid\b', 'n/2'), + (r'\blow\b.*\bhigh\b', 'n/2'), + ] + + reduction = 'n-1' # Default + for pattern, result in patterns: + if re.search(pattern, code_lower): + reduction = result + break + + return reduction, branching + + def _compute_recursion_complexity(self, branching: int, reduction: str) -> ComplexityClass: + """Compute recursion complexity 
using Master Theorem.""" + if "n/2" in reduction: + # Divide and conquer + if branching == 1: + return ComplexityClass.LOGARITHMIC # Binary search + elif branching == 2: + return ComplexityClass.LINEARITHMIC # Merge sort + else: + return ComplexityClass.POLYNOMIAL + else: + # Linear reduction (n-1) + if branching == 1: + return ComplexityClass.LINEAR # Simple recursion + elif branching == 2: + return ComplexityClass.EXPONENTIAL # Fibonacci-like + else: + return ComplexityClass.EXPONENTIAL + + def _build_recurrence(self, branching: int, reduction: str) -> str: + """Build recurrence relation string.""" + if "n/2" in reduction: + return f"T(n) = {branching}T(n/2) + O(1)" + else: + return f"T(n) = {branching}T(n-1) + O(1)" + + def _analyze_patterns(self, code: str) -> AnalysisResult: + """Analyze code using pattern matching (fallback).""" + lines = code.split('\n') + + # Detect function definitions + func_match = re.search(r'(function|algorithm|def|procedure)\s+(\w+)', code, re.IGNORECASE) + if func_match: + self.current_function = func_match.group(2) + + # Analyze each line for loops + self._detect_loops_from_code(lines) + + # Detect recursion + if self.current_function: + call_pattern = rf'\b{self.current_function}\s*\(' + calls = re.findall(call_pattern, code, re.IGNORECASE) + if len(calls) > 1: + reduction, branching = self._analyze_recursive_calls(self.current_function, code) + complexity = self._compute_recursion_complexity(branching, reduction) + self.recursion_info = RecursionInfo( + function_name=self.current_function, + num_recursive_calls=len(calls) - 1, + reduction_pattern=reduction, + branching_factor=branching, + work_per_call=ComplexityClass.CONSTANT, + complexity=complexity, + recurrence=self._build_recurrence(branching, reduction) + ) + + return self._compute_result(code) + + def _detect_loops_from_code(self, lines: List[str]): + """Detect loops from raw code lines.""" + indent_stack: List[Tuple[int, LoopInfo]] = [] + + for line_num, line in 
enumerate(lines, 1): + stripped = line.strip().lower() + if not stripped: + continue + + indent = len(line) - len(line.lstrip()) + + # Pop loops that have ended + while indent_stack and indent <= indent_stack[-1][0]: + indent_stack.pop() + + nesting_level = len(indent_stack) + + # Check for FOR loop + for_match = re.match(r'for\s+(\w+)\s*[=:]\s*(\w+)\s+to\s+(\w+)', stripped) + if for_match: + loop_info = self._create_loop_info_from_match( + "for", for_match, nesting_level, line_num + ) + if indent_stack: + indent_stack[-1][1].nested_loops.append(loop_info) + else: + self.loops.append(loop_info) + indent_stack.append((indent, loop_info)) + continue + + # Check for FOR EACH loop + foreach_match = re.match(r'for\s+(?:each\s+)?(\w+)\s+in\s+(\w+)', stripped) + if foreach_match: + loop_info = LoopInfo( + loop_type="for_each", + iterator=foreach_match.group(1), + start_bound=None, + end_bound=foreach_match.group(2), + step=None, + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=nesting_level, + line_number=line_num + ) + if indent_stack: + indent_stack[-1][1].nested_loops.append(loop_info) + else: + self.loops.append(loop_info) + indent_stack.append((indent, loop_info)) + continue + + # Check for WHILE loop + if stripped.startswith('while ') or stripped.startswith('while('): + loop_info = LoopInfo( + loop_type="while", + iterator=None, + start_bound=None, + end_bound=None, + step=None, + iterations=self._estimate_while_iterations(stripped), + complexity=ComplexityClass.LINEAR, + nesting_level=nesting_level, + line_number=line_num + ) + + # Check for logarithmic pattern + if any(p in stripped for p in ['/2', '//2', '* 2', '*2', 'log']): + loop_info.iterations = "log(n)" + loop_info.complexity = ComplexityClass.LOGARITHMIC + + if indent_stack: + indent_stack[-1][1].nested_loops.append(loop_info) + else: + self.loops.append(loop_info) + indent_stack.append((indent, loop_info)) + continue + + # Check for REPEAT loop + if stripped.startswith('repeat'): + 
loop_info = LoopInfo( + loop_type="repeat", + iterator=None, + start_bound=None, + end_bound=None, + step=None, + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=nesting_level, + line_number=line_num + ) + if indent_stack: + indent_stack[-1][1].nested_loops.append(loop_info) + else: + self.loops.append(loop_info) + indent_stack.append((indent, loop_info)) + + def _create_loop_info_from_match(self, loop_type: str, match, + nesting_level: int, line_num: int) -> LoopInfo: + """Create LoopInfo from regex match.""" + iterator = match.group(1) + start = match.group(2) + end = match.group(3) + + # Estimate iterations + iterations = self._estimate_iterations_from_bounds(start, end) + complexity = self._iterations_to_complexity(iterations) + + return LoopInfo( + loop_type=loop_type, + iterator=iterator, + start_bound=start, + end_bound=end, + step="1", + iterations=iterations, + complexity=complexity, + nesting_level=nesting_level, + line_number=line_num + ) + + def _estimate_iterations_from_bounds(self, start: str, end: str) -> str: + """Estimate iterations from loop bounds.""" + start_lower = start.lower() + end_lower = end.lower() + + if start_lower in ("0", "1"): + if end_lower in ("n", "len", "length", "size", "count"): + return "n" + if "n-" in end_lower: + return "n" + if "n/2" in end_lower: + return "n/2" + if "log" in end_lower: + return "log(n)" + + try: + s = int(start) + e = int(end) + return str(e - s + 1) + except (ValueError, TypeError): + pass + + return "n" + + def _estimate_while_iterations(self, condition: str) -> str: + """Estimate while loop iterations from condition.""" + if '/2' in condition or '//2' in condition or '* 2' in condition: + return "log(n)" + return "n" + + def _compute_result(self, code: str) -> AnalysisResult: + """Compute final analysis result.""" + time_complexity = ComplexityClass.CONSTANT + max_nesting = 0 + + # Compute from loops + for loop in self.loops: + loop_complexity = 
self._compute_total_loop_complexity(loop) + if ComplexityClass.compare(loop_complexity, time_complexity) > 0: + time_complexity = loop_complexity + max_nesting = max(max_nesting, self._get_max_nesting(loop)) + + # Consider recursion + if self.recursion_info: + rec_complexity = self.recursion_info.complexity + if ComplexityClass.compare(rec_complexity, time_complexity) > 0: + time_complexity = rec_complexity + + # Compute space complexity + space_complexity = self._compute_space_complexity() + + # Build factors list + factors = self._build_factors() + + # Confidence based on analysis quality + confidence = self._compute_confidence() + + return AnalysisResult( + time_complexity=time_complexity, + space_complexity=space_complexity, + loops=self.loops, + recursion=self.recursion_info, + max_nesting_depth=max_nesting, + confidence=confidence, + factors=factors, + raw_code=code + ) + + def _compute_total_loop_complexity(self, loop: LoopInfo) -> ComplexityClass: + """Compute total complexity of a loop including nested loops.""" + result = loop.complexity + + for nested in loop.nested_loops: + nested_complexity = self._compute_total_loop_complexity(nested) + result = ComplexityClass.multiply(result, nested_complexity) + + return result + + def _get_max_nesting(self, loop: LoopInfo) -> int: + """Get maximum nesting depth.""" + if not loop.nested_loops: + return loop.nesting_level + 1 + return max(self._get_max_nesting(n) for n in loop.nested_loops) + + def _compute_space_complexity(self) -> ComplexityClass: + """Compute space complexity.""" + if self.recursion_info: + # Recursive = stack depth + if "n/2" in self.recursion_info.reduction_pattern: + return ComplexityClass.LOGARITHMIC + else: + return ComplexityClass.LINEAR + + return ComplexityClass.CONSTANT + + def _build_factors(self) -> List[ComplexityFactor]: + """Build list of complexity factors.""" + factors = [] + + for loop in self.loops: + factors.append(ComplexityFactor( + source=loop.get_description(), + 
factor_type="loop", + complexity=loop.complexity, + iterations=loop.iterations, + nesting_level=loop.nesting_level, + location=f"line {loop.line_number}" if loop.line_number else None + )) + + if self.recursion_info: + factors.append(ComplexityFactor( + source=self.recursion_info.get_description(), + factor_type="recursion", + complexity=self.recursion_info.complexity, + iterations=self.recursion_info.recurrence + )) + + return factors + + def _compute_confidence(self) -> float: + """Compute confidence in the analysis.""" + confidence = 0.8 # Base confidence + + if self.loops or self.recursion_info: + confidence = 0.9 + + if self.recursion_info and self.loops: + confidence = 0.85 # Mixed analysis is harder + + return confidence diff --git a/evaluation_function/analyzer/feedback_generator.py b/evaluation_function/analyzer/feedback_generator.py new file mode 100644 index 0000000..817ac51 --- /dev/null +++ b/evaluation_function/analyzer/feedback_generator.py @@ -0,0 +1,591 @@ +""" +Feedback Generator - Generates detailed human-readable feedback for complexity analysis. 
+ +This module provides clear, educational feedback explaining: +- What complexity was detected and why +- How loops and recursion contribute to complexity +- Step-by-step breakdown of the analysis +- Suggestions for improvement +""" + +from typing import List, Optional, Dict, Any +from dataclasses import dataclass, field +from enum import Enum + +from ..schemas.complexity import ComplexityClass +from .complexity_analyzer import AnalysisResult, LoopInfo, RecursionInfo + + +class FeedbackLevel(str, Enum): + """Level of detail for feedback.""" + BRIEF = "brief" # Just the result + STANDARD = "standard" # Result with explanation + DETAILED = "detailed" # Full breakdown with examples + + +@dataclass +class FeedbackSection: + """A section of feedback.""" + title: str + content: str + importance: str = "info" # "info", "warning", "success", "error" + + +@dataclass +class DetailedFeedback: + """Complete feedback for complexity analysis.""" + # Summary + summary: str + complexity_result: str + + # Breakdown sections + sections: List[FeedbackSection] = field(default_factory=list) + + # Quick facts + loop_count: int = 0 + max_nesting: int = 0 + has_recursion: bool = False + + # Educational content + complexity_explanation: str = "" + real_world_example: str = "" + + # Suggestions + suggestions: List[str] = field(default_factory=list) + + # Confidence + confidence_note: str = "" + + def to_string(self, level: FeedbackLevel = FeedbackLevel.STANDARD) -> str: + """Convert feedback to formatted string.""" + lines = [] + + # Header + lines.append("=" * 60) + lines.append("COMPLEXITY ANALYSIS RESULT") + lines.append("=" * 60) + lines.append("") + + # Summary + lines.append(f"Time Complexity: {self.complexity_result}") + lines.append("") + lines.append(self.summary) + lines.append("") + + if level == FeedbackLevel.BRIEF: + return "\n".join(lines) + + # Sections + for section in self.sections: + lines.append("-" * 40) + lines.append(f"[{section.importance.upper()}] {section.title}") 
+ lines.append("-" * 40) + lines.append(section.content) + lines.append("") + + # Complexity explanation + if self.complexity_explanation: + lines.append("-" * 40) + lines.append("What does this mean?") + lines.append("-" * 40) + lines.append(self.complexity_explanation) + lines.append("") + + if level == FeedbackLevel.STANDARD: + # Add suggestions + if self.suggestions: + lines.append("-" * 40) + lines.append("Suggestions") + lines.append("-" * 40) + for i, suggestion in enumerate(self.suggestions, 1): + lines.append(f" {i}. {suggestion}") + lines.append("") + return "\n".join(lines) + + # Detailed level - add everything + if self.real_world_example: + lines.append("-" * 40) + lines.append("Real-World Example") + lines.append("-" * 40) + lines.append(self.real_world_example) + lines.append("") + + if self.suggestions: + lines.append("-" * 40) + lines.append("Optimization Suggestions") + lines.append("-" * 40) + for i, suggestion in enumerate(self.suggestions, 1): + lines.append(f" {i}. {suggestion}") + lines.append("") + + if self.confidence_note: + lines.append("-" * 40) + lines.append("Analysis Confidence") + lines.append("-" * 40) + lines.append(self.confidence_note) + + return "\n".join(lines) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for JSON serialization.""" + return { + "summary": self.summary, + "complexity": self.complexity_result, + "sections": [ + {"title": s.title, "content": s.content, "importance": s.importance} + for s in self.sections + ], + "stats": { + "loop_count": self.loop_count, + "max_nesting": self.max_nesting, + "has_recursion": self.has_recursion + }, + "explanation": self.complexity_explanation, + "real_world_example": self.real_world_example, + "suggestions": self.suggestions, + "confidence": self.confidence_note + } + + +class FeedbackGenerator: + """ + Generates detailed, human-readable feedback for complexity analysis results. 
+ + The feedback is designed to be educational and help students understand: + - Why their code has a particular complexity + - How different constructs contribute to complexity + - How to improve their algorithms + """ + + # Complexity descriptions and examples + COMPLEXITY_INFO = { + ComplexityClass.CONSTANT: { + "name": "Constant Time", + "symbol": "O(1)", + "description": "The algorithm takes the same amount of time regardless of input size.", + "example": "Accessing an array element by index, checking if a number is even/odd.", + "growth": "No growth - always the same time." + }, + ComplexityClass.LOGARITHMIC: { + "name": "Logarithmic Time", + "symbol": "O(log n)", + "description": "The algorithm reduces the problem size by half (or a constant factor) at each step.", + "example": "Binary search - each comparison eliminates half the remaining elements.", + "growth": "Very slow growth. Doubling n adds only one extra step." + }, + ComplexityClass.SQRT: { + "name": "Square Root Time", + "symbol": "O(√n)", + "description": "The algorithm's time grows with the square root of the input size.", + "example": "Checking if a number is prime by testing divisors up to √n.", + "growth": "Slow growth. For n=1,000,000, only ~1,000 operations." + }, + ComplexityClass.LINEAR: { + "name": "Linear Time", + "symbol": "O(n)", + "description": "The algorithm examines each element once. Time grows proportionally with input size.", + "example": "Finding the maximum element in an unsorted array - must check every element.", + "growth": "Direct growth. Doubling n doubles the time." + }, + ComplexityClass.LINEARITHMIC: { + "name": "Linearithmic Time", + "symbol": "O(n log n)", + "description": "Common in efficient sorting algorithms that use divide-and-conquer.", + "example": "Merge sort, quicksort (average case), heapsort.", + "growth": "Slightly faster than linear growth. Very efficient for sorting." 
+ }, + ComplexityClass.QUADRATIC: { + "name": "Quadratic Time", + "symbol": "O(n²)", + "description": "Usually caused by nested loops where both iterate over the input.", + "example": "Bubble sort, selection sort, comparing all pairs of elements.", + "growth": "Doubling n quadruples the time. Becomes slow for large inputs." + }, + ComplexityClass.CUBIC: { + "name": "Cubic Time", + "symbol": "O(n³)", + "description": "Often seen with triple nested loops, common in matrix operations.", + "example": "Standard matrix multiplication, Floyd-Warshall algorithm.", + "growth": "Doubling n increases time by 8x. Impractical for large n." + }, + ComplexityClass.POLYNOMIAL: { + "name": "Polynomial Time", + "symbol": "O(n^k)", + "description": "Time grows as n raised to some power k > 3.", + "example": "Some dynamic programming solutions, brute-force algorithms.", + "growth": "Very fast growth. Only practical for small inputs." + }, + ComplexityClass.EXPONENTIAL: { + "name": "Exponential Time", + "symbol": "O(2^n)", + "description": "The algorithm doubles its work for each additional input element. Often from recursive solutions that branch multiple times.", + "example": "Naive recursive Fibonacci, generating all subsets, brute-force combinatorial problems.", + "growth": "Extremely fast growth! Adding one element doubles the time." + }, + ComplexityClass.FACTORIAL: { + "name": "Factorial Time", + "symbol": "O(n!)", + "description": "The worst common complexity class. Generates all permutations.", + "example": "Brute-force traveling salesman, generating all permutations.", + "growth": "Astronomically fast growth. n=20 has more operations than atoms in the universe!" + } + } + + def __init__(self): + pass + + def generate(self, result: AnalysisResult, + level: FeedbackLevel = FeedbackLevel.STANDARD) -> DetailedFeedback: + """ + Generate detailed feedback from analysis result. 
+ + Args: + result: The complexity analysis result + level: Level of detail for feedback + + Returns: + DetailedFeedback object with complete feedback + """ + feedback = DetailedFeedback( + summary=self._generate_summary(result), + complexity_result=result.time_complexity.value, + loop_count=len(result.loops), + max_nesting=result.max_nesting_depth, + has_recursion=result.recursion is not None + ) + + # Add analysis sections + feedback.sections = self._generate_sections(result) + + # Add complexity explanation + feedback.complexity_explanation = self._explain_complexity(result.time_complexity) + + # Add real-world example + feedback.real_world_example = self._get_real_world_example(result.time_complexity) + + # Generate suggestions + feedback.suggestions = self._generate_suggestions(result) + + # Add confidence note + feedback.confidence_note = self._generate_confidence_note(result) + + return feedback + + def _generate_summary(self, result: AnalysisResult) -> str: + """Generate a summary of the analysis.""" + complexity = result.time_complexity + info = self.COMPLEXITY_INFO.get(complexity, {}) + name = info.get("name", complexity.value) + + parts = [] + parts.append(f"Your algorithm has {name} complexity: {complexity.value}.") + + # Add context based on what was found + if result.recursion: + rec = result.recursion + if rec.branching_factor == 1: + parts.append(f"This is due to recursion in {rec.function_name}() with {rec.reduction_pattern} reduction.") + else: + parts.append(f"This is due to {rec.branching_factor}-way recursion in {rec.function_name}().") + elif result.loops: + if result.max_nesting_depth > 1: + parts.append(f"This is caused by {result.max_nesting_depth} levels of nested loops.") + elif len(result.loops) == 1: + loop = result.loops[0] + parts.append(f"This is due to a {loop.loop_type.upper()} loop that iterates {loop.iterations} times.") + else: + parts.append(f"This is due to {len(result.loops)} loops in your code.") + else: + parts.append("No 
loops or recursion detected - the algorithm runs in constant time.") + + return " ".join(parts) + + def _generate_sections(self, result: AnalysisResult) -> List[FeedbackSection]: + """Generate detailed analysis sections.""" + sections = [] + + # Loop analysis section + if result.loops: + sections.append(self._generate_loop_section(result)) + + # Recursion analysis section + if result.recursion: + sections.append(self._generate_recursion_section(result.recursion)) + + # Nesting analysis + if result.max_nesting_depth > 1: + sections.append(self._generate_nesting_section(result)) + + # Complexity breakdown + sections.append(self._generate_breakdown_section(result)) + + return sections + + def _generate_loop_section(self, result: AnalysisResult) -> FeedbackSection: + """Generate section explaining loop analysis.""" + lines = [] + + for i, loop in enumerate(result.loops, 1): + lines.append(f"Loop {i}: {loop.get_description()}") + lines.append(f" - Iterations: {loop.iterations}") + lines.append(f" - Complexity contribution: {loop.complexity.value}") + + if loop.nested_loops: + lines.append(f" - Contains {len(loop.nested_loops)} nested loop(s):") + for j, nested in enumerate(loop.nested_loops, 1): + lines.append(f" {j}. 
{nested.get_description()} - {nested.iterations} iterations") + + lines.append("") + + # Explain how loops combine + if len(result.loops) > 1 or any(l.nested_loops for l in result.loops): + lines.append("How loops combine:") + lines.append(" - Nested loops: complexities MULTIPLY (O(n) × O(n) = O(n²))") + lines.append(" - Sequential loops: take the MAXIMUM (O(n) + O(n) = O(n))") + + return FeedbackSection( + title="Loop Analysis", + content="\n".join(lines), + importance="info" + ) + + def _generate_recursion_section(self, rec: RecursionInfo) -> FeedbackSection: + """Generate section explaining recursion analysis.""" + lines = [] + + lines.append(f"Recursive Function: {rec.function_name}()") + lines.append(f" - Number of recursive calls: {rec.branching_factor} per invocation") + lines.append(f" - Problem reduction: {rec.reduction_pattern}") + lines.append(f" - Recurrence relation: {rec.recurrence}") + lines.append("") + + # Explain the recursion pattern + if rec.branching_factor == 1: + if "n/2" in rec.reduction_pattern: + lines.append("Pattern: Single recursive call with halving (like binary search)") + lines.append("Each call processes half the remaining problem.") + lines.append("Total calls: log₂(n), giving O(log n) complexity.") + else: + lines.append("Pattern: Linear recursion (like factorial)") + lines.append("Each call reduces problem by 1, requiring n calls total.") + lines.append("Result: O(n) complexity.") + elif rec.branching_factor == 2: + if "n/2" in rec.reduction_pattern: + lines.append("Pattern: Divide-and-conquer (like merge sort)") + lines.append("Problem splits in half, but both halves are processed.") + lines.append("Using Master Theorem: O(n log n) complexity.") + else: + lines.append("Pattern: Binary recursion with linear reduction (like naive Fibonacci)") + lines.append("WARNING: This creates an exponential number of calls!") + lines.append("Each call spawns 2 more, leading to O(2^n) complexity.") + else: + lines.append(f"Pattern: 
{rec.branching_factor}-way branching recursion") + lines.append("Multiple recursive calls lead to exponential growth.") + + return FeedbackSection( + title="Recursion Analysis", + content="\n".join(lines), + importance="warning" if rec.complexity in [ComplexityClass.EXPONENTIAL, ComplexityClass.FACTORIAL] else "info" + ) + + def _generate_nesting_section(self, result: AnalysisResult) -> FeedbackSection: + """Generate section about loop nesting.""" + depth = result.max_nesting_depth + + lines = [] + lines.append(f"Maximum nesting depth: {depth} levels") + lines.append("") + + # Explain impact + if depth == 2: + lines.append("Two nested loops typically result in O(n²) quadratic complexity.") + lines.append("Example: for i in 1..n: for j in 1..n: → n × n = n² operations") + elif depth == 3: + lines.append("Three nested loops result in O(n³) cubic complexity.") + lines.append("This grows VERY quickly - use with caution for large inputs!") + elif depth > 3: + lines.append(f"{depth} levels of nesting creates O(n^{depth}) complexity.") + lines.append("This is extremely slow for large inputs!") + lines.append("Consider if all these nested loops are necessary.") + + return FeedbackSection( + title="Nesting Impact", + content="\n".join(lines), + importance="warning" if depth >= 3 else "info" + ) + + def _generate_breakdown_section(self, result: AnalysisResult) -> FeedbackSection: + """Generate complexity breakdown section.""" + lines = [] + lines.append("Step-by-step complexity calculation:") + lines.append("") + + if not result.loops and not result.recursion: + lines.append("1. No loops or recursion detected") + lines.append("2. Only simple operations (assignments, comparisons)") + lines.append("3. Each operation takes constant time O(1)") + lines.append("→ Final complexity: O(1)") + elif result.recursion and not result.loops: + rec = result.recursion + lines.append(f"1. Found recursive function: {rec.function_name}()") + lines.append(f"2. 
Recurrence: {rec.recurrence}") + if "n/2" in rec.reduction_pattern: + lines.append("3. Problem halves each call → logarithmic depth") + if rec.branching_factor == 1: + lines.append("4. Single recursive call per level") + lines.append("→ Final complexity: O(log n)") + else: + lines.append(f"4. {rec.branching_factor} calls per level") + lines.append("5. Apply Master Theorem") + lines.append(f"→ Final complexity: {rec.complexity.value}") + else: + lines.append(f"3. Linear reduction ({rec.reduction_pattern}) → n levels deep") + if rec.branching_factor == 1: + lines.append("→ Final complexity: O(n)") + else: + lines.append(f"4. {rec.branching_factor}^n calls total") + lines.append(f"→ Final complexity: {rec.complexity.value}") + else: + # Loop-based complexity + step = 1 + current = ComplexityClass.CONSTANT + + for loop in result.loops: + lines.append(f"{step}. {loop.get_description()}") + lines.append(f" Iterates {loop.iterations} times → {loop.complexity.value}") + step += 1 + + for nested in loop.nested_loops: + lines.append(f"{step}. 
def _explain_complexity(self, complexity: "ComplexityClass") -> str:
    """Return a short, multi-paragraph explanation of a complexity class.

    The text is assembled from the entry for ``complexity`` in
    ``self.COMPLEXITY_INFO`` (keys ``name``, ``symbol``, ``description``,
    ``growth``).

    Args:
        complexity: The complexity class to explain. Its ``value`` is used
            as the display name when no ``name`` entry exists.

    Returns:
        ``"<name> - <symbol>"``, the description and a ``Growth rate:``
        line, separated by blank lines. Parts with no data are omitted, so
        a class without an info entry yields just its ``value`` instead of
        dangling separators and empty paragraphs.
    """
    info = self.COMPLEXITY_INFO.get(complexity, {})
    name = info.get('name', complexity.value)
    symbol = info.get('symbol', '')
    description = info.get('description', '')
    growth = info.get('growth', '')

    # Only add the " - <symbol>" suffix when there is a symbol to show.
    lines = [f"{name} - {symbol}" if symbol else name]

    if description:
        lines.extend(["", description])

    if growth:
        lines.extend(["", f"Growth rate: {growth}"])

    return "\n".join(lines)
ComplexityClass.EXPONENTIAL: + lines.append("") + lines.append("At different scales:") + lines.append(" n=10: ~1,024 operations") + lines.append(" n=20: ~1,048,576 operations") + lines.append(" n=30: ~1,073,741,824 operations (very slow!)") + lines.append(" n=50: More than age of universe in nanoseconds!") + + return "\n".join(lines) + + def _generate_suggestions(self, result: AnalysisResult) -> List[str]: + """Generate optimization suggestions.""" + suggestions = [] + complexity = result.time_complexity + + # Suggestions based on complexity + if complexity == ComplexityClass.EXPONENTIAL: + suggestions.append("Consider using dynamic programming to avoid redundant calculations.") + suggestions.append("Memoization can often convert O(2^n) to O(n) or O(n²).") + if result.recursion and result.recursion.branching_factor >= 2: + suggestions.append(f"Your {result.recursion.function_name}() has overlapping subproblems - cache results!") + + if complexity == ComplexityClass.QUADRATIC and result.max_nesting_depth >= 2: + suggestions.append("Look for ways to eliminate one of the nested loops.") + suggestions.append("Consider using a hash table/dictionary to replace the inner loop with O(1) lookup.") + suggestions.append("Sorting first might enable a more efficient algorithm.") + + if complexity == ComplexityClass.CUBIC: + suggestions.append("Triple nested loops are usually avoidable with better algorithms.") + suggestions.append("For matrix operations, consider Strassen's algorithm or library functions.") + + # Suggestions based on structure + if result.max_nesting_depth >= 3: + suggestions.append("Deep nesting makes code hard to read and slow. 
def _generate_confidence_note(self, result: "AnalysisResult") -> str:
    """Translate the analysis confidence score into a one-line note.

    Args:
        result: Analysis result; only its ``confidence`` attribute is read.

    Returns:
        The "High", "Moderate" or "Lower" confidence sentence, chosen by
        the first threshold (0.9, then 0.7) that the score reaches.
    """
    score = result.confidence

    # Ordered from the strictest threshold down; the first one met wins.
    tiers = (
        (0.9, "High confidence: Clear loop/recursion patterns detected."),
        (0.7, "Moderate confidence: Analysis based on detected patterns."),
    )
    for threshold, note in tiers:
        if score >= threshold:
            return note

    return "Lower confidence: Some constructs may not have been fully analyzed."
+ """ + feedback = self.generate(result, FeedbackLevel.DETAILED) + return feedback.to_string(FeedbackLevel.DETAILED) + + def format_brief(self, result: AnalysisResult) -> str: + """Get brief one-line feedback.""" + feedback = self.generate(result, FeedbackLevel.BRIEF) + return f"Time Complexity: {result.time_complexity.value} - {feedback.summary}" diff --git a/evaluation_function/evaluation.py b/evaluation_function/evaluation.py index 61ecaa3..9aaa1cf 100755 --- a/evaluation_function/evaluation.py +++ b/evaluation_function/evaluation.py @@ -1,34 +1,412 @@ -from typing import Any +""" +Evaluation function for pseudocode complexity analysis. + +This module implements the main evaluation pipeline that: +1. Parses the student's pseudocode +2. Analyzes its time and space complexity +3. Compares against the expected complexity bound +4. The code is CORRECT if its complexity is <= the expected bound + +The answer specifies a complexity upper bound that the student's code must meet. +""" + +from typing import Any, Dict, Optional, Tuple from lf_toolkit.evaluation import Result, Params +from .parser.parser import PseudocodeParser +from .analyzer.complexity_analyzer import ComplexityAnalyzer, AnalysisResult +from .schemas.complexity import ComplexityClass + + def evaluation_function( response: Any, answer: Any, params: Params, ) -> Result: """ - Function used to evaluate a student response. - --- - The handler function passes three arguments to evaluation_function(): - - - `response` which are the answers provided by the student. - - `answer` which are the correct answers to compare against. - - `params` which are any extra parameters that may be useful, - e.g., error tolerances. - - The output of this function is what is returned as the API response - and therefore must be JSON-encodable. It must also conform to the - response schema. - - Any standard python library may be used, as well as any package - available on pip (provided it is added to requirements.txt). 
- - The way you wish to structure you code (all in this function, or - split into many) is entirely up to you. All that matters are the - return types and that evaluation_function() is the main function used - to output the evaluation response. + Evaluate a student's pseudocode complexity. + + The evaluation checks if the student's code has complexity within + the expected bound. A student's answer is CORRECT if their code's + complexity is less than or equal to the expected complexity. + + Args: + response: Student's response containing: + - pseudocode: The pseudocode string + - time_complexity (optional): Student's stated time complexity + - space_complexity (optional): Student's stated space complexity + answer: Expected answer containing: + - expected_time_complexity: Upper bound for time complexity (e.g., "O(n^2)") + - expected_space_complexity (optional): Upper bound for space complexity + params: Evaluation parameters for customization + + Returns: + Result with is_correct, score, and detailed feedback + """ + try: + # Parse inputs + pseudocode, student_time, student_space = _parse_response(response) + expected_time, expected_space, eval_options = _parse_answer(answer, params) + + # Validate inputs + if not pseudocode: + result = Result(is_correct=False) + result.add_feedback("error", "No pseudocode provided. 
def _parse_response(response: Any) -> Tuple[str, Optional[str], Optional[str]]:
    """Extract the pseudocode and any stated complexities from a response.

    Args:
        response: Either the pseudocode itself (a string) or a dict with a
            ``pseudocode`` (or ``code``) entry and optional
            ``time_complexity`` / ``space_complexity`` entries.

    Returns:
        ``(pseudocode, time_complexity, space_complexity)``. ``pseudocode``
        is always a string; it is ``''`` when the response holds no usable
        code (wrong type, missing keys, or only blank / non-string values).
        The stated complexities are passed through as given, ``None`` when
        absent.
    """
    if isinstance(response, str):
        return response, None, None

    if isinstance(response, dict):
        # Use the first key that actually holds non-blank text. A key that is
        # present but None/empty/non-string must not block the 'code' fallback
        # or leak a non-string value to the parser.
        pseudocode = ''
        for key in ('pseudocode', 'code'):
            candidate = response.get(key)
            if isinstance(candidate, str) and candidate.strip():
                pseudocode = candidate
                break

        return (
            pseudocode,
            response.get('time_complexity'),
            response.get('space_complexity'),
        )

    return '', None, None
def _calculate_result(
    time_result: Dict,
    space_result: Optional[Dict],
    eval_options: Dict
) -> Tuple[bool, float]:
    """Combine the time (and optional space) verdicts into one result.

    Args:
        time_result: Dict with at least ``is_correct``, ``detected`` and
            ``expected`` for the time-complexity check.
        space_result: Same shape for the space check, or ``None`` when no
            space bound was requested.
        eval_options: Optional settings: ``time_weight`` (default 0.7),
            ``space_weight`` (default 0.3) and ``partial_credit``
            (default True).

    Returns:
        ``(is_correct, score)``. ``score`` is in ``[0, 1]``: the weights are
        normalised by their sum, so overriding only one of them cannot push
        a fully correct answer above 1.0.
    """
    time_weight = eval_options.get('time_weight', 0.7)
    space_weight = eval_options.get('space_weight', 0.3)
    partial_credit = eval_options.get('partial_credit', True)

    time_ok = time_result['is_correct']

    # If no space requirement, only consider time
    if space_result is None:
        if partial_credit and not time_ok:
            # Give partial credit based on how close they are
            score = _partial_score(time_result['detected'], time_result['expected'])
        else:
            score = 1.0 if time_ok else 0.0
        return time_ok, score

    # Both time and space required
    space_ok = space_result['is_correct']
    is_correct = time_ok and space_ok

    if not partial_credit:
        return is_correct, (1.0 if is_correct else 0.0)

    time_score = 1.0 if time_ok else _partial_score(
        time_result['detected'], time_result['expected']
    )
    space_score = 1.0 if space_ok else _partial_score(
        space_result['detected'], space_result['expected']
    )

    total_weight = time_weight + space_weight
    if total_weight <= 0:
        # Unusable weighting supplied; fall back to the documented defaults.
        time_weight, space_weight, total_weight = 0.7, 0.3, 1.0

    score = (time_weight * time_score + space_weight * space_score) / total_weight
    return is_correct, score


def _partial_score(detected: "ComplexityClass", expected: "ComplexityClass") -> float:
    """Calculate a partial-credit score from how far ``detected`` misses ``expected``.

    Positions are taken from ``ComplexityClass.get_order()``; a lower index
    is treated as meeting the bound.

    Returns:
        1.0 when ``detected`` is at or before ``expected`` in the ordering;
        otherwise 0.5 for one level above, reduced by 0.15 for each further
        level and floored at 0.0. Returns 0.0 if either class is missing
        from the ordering.
    """
    order = ComplexityClass.get_order()

    try:
        detected_idx = order.index(detected)
        expected_idx = order.index(expected)
    except ValueError:
        # One of the classes is not part of the ordering: no credit.
        return 0.0

    if detected_idx <= expected_idx:
        return 1.0  # Met or exceeded requirement

    levels_over = detected_idx - expected_idx
    # One level above the bound earns 0.5; each additional level costs 0.15.
    return max(0.0, 0.5 - (levels_over - 1) * 0.15)
Your algorithm is more efficient than required.") + else: + lines.append(f" ✓ Your algorithm meets the time complexity requirement.") + else: + lines.append(f" ✗ Your algorithm exceeds the allowed time complexity.") + lines.append(f" Try to optimize your algorithm to achieve {time_result['expected'].value}.") + + # Student's stated complexity feedback + if time_result.get('student_stated'): + if time_result.get('student_stated_correct'): + lines.append(f" ✓ Your stated complexity ({time_result['student_stated']}) matches the detected complexity.") + else: + lines.append(f" ⚠ Your stated complexity ({time_result['student_stated']}) differs from detected ({time_result['detected'].value}).") + + lines.append("") + + # Space complexity feedback (if applicable) + if space_result: + space_correct = space_result['is_correct'] + lines.append("Space Complexity:") + lines.append(f" • Required: {space_result['expected'].value} or better") + lines.append(f" • Detected: {space_result['detected'].value}") + + if space_correct: + lines.append(f" ✓ Your algorithm meets the space complexity requirement.") + else: + lines.append(f" ✗ Your algorithm exceeds the allowed space complexity.") + lines.append("") + + # Detailed analysis (if enabled) + if eval_options.get('show_detailed_feedback', True) and not is_correct: + lines.append("-" * 50) + lines.append("Analysis Details:") + + if analysis.loops: + lines.append(f" • Found {len(analysis.loops)} loop(s)") + if analysis.max_nesting_depth > 1: + lines.append(f" • Maximum nesting depth: {analysis.max_nesting_depth}") + + if analysis.recursion: + rec = analysis.recursion + lines.append(f" • Recursive function: {rec.function_name}()") + lines.append(f" • Recurrence: {rec.recurrence}") + + # Suggestions + lines.append("") + lines.append("Suggestions:") + if time_result['detected'] == ComplexityClass.QUADRATIC and time_result['expected'] == ComplexityClass.LINEAR: + lines.append(" • Consider if you can eliminate one of the nested loops") + 
def _build_feedback_items(
    time_result: Dict,
    space_result: Optional[Dict],
    is_correct: bool,
    score: float
) -> list:
    """
    Build feedback items as list of (level, message) tuples.

    Levels: 'success', 'warning', 'error', 'info'

    Args:
        time_result: Time-complexity verdict; reads ``is_correct``,
            ``comparison``, ``detected`` / ``expected`` (their ``.value``)
            and the optional ``student_stated`` /
            ``student_stated_correct`` entries.
        space_result: Space-complexity verdict of the same shape, or
            ``None`` when space was not evaluated.
        is_correct: Overall verdict.
        score: Overall score; reported as partial credit when the answer
            is not correct but the score is positive.

    Returns:
        The feedback items in display order.
    """
    # No function-scope typing import: the item shape is documented above.
    items = []

    # Overall result
    if is_correct:
        items.append(("success", "Your algorithm meets the complexity requirements."))
    else:
        items.append(("error", "Your algorithm does not meet the complexity requirements."))

    # Time complexity feedback
    detected = time_result['detected'].value
    expected = time_result['expected'].value

    if not time_result['is_correct']:
        items.append(("error", f"Time complexity: {detected} (exceeds required {expected})"))
        items.append(("info", f"Try to optimize your algorithm to achieve {expected} or better."))
    elif time_result['comparison'] < 0:
        items.append(("success", f"Time complexity: {detected} (better than required {expected})"))
    else:
        items.append(("success", f"Time complexity: {detected} (meets requirement)"))

    # Student's stated complexity feedback
    stated = time_result.get('student_stated')
    if stated:
        if time_result.get('student_stated_correct'):
            items.append(("success", f"Your stated complexity ({stated}) matches the detected complexity."))
        else:
            items.append(("warning", f"Your stated complexity ({stated}) differs from detected ({detected})."))

    # Space complexity feedback (if applicable)
    if space_result:
        space_detected = space_result['detected'].value
        space_expected = space_result['expected'].value

        if space_result['is_correct']:
            items.append(("success", f"Space complexity: {space_detected} (meets requirement)"))
        else:
            items.append(("error", f"Space complexity: {space_detected} (exceeds required {space_expected})"))

    # Score info
    if not is_correct and score > 0:
        items.append(("info", f"Partial credit: {score:.0%}"))

    return items
b/evaluation_function/evaluation_test.py deleted file mode 100755 index 7a5c5bd..0000000 --- a/evaluation_function/evaluation_test.py +++ /dev/null @@ -1,30 +0,0 @@ -import unittest - -from .evaluation import Params, evaluation_function - -class TestEvaluationFunction(unittest.TestCase): - """ - TestCase Class used to test the algorithm. - --- - Tests are used here to check that the algorithm written - is working as it should. - - It's best practise to write these tests first to get a - kind of 'specification' for how your algorithm should - work, and you should run these tests before committing - your code to AWS. - - Read the docs on how to use unittest here: - https://docs.python.org/3/library/unittest.html - - Use evaluation_function() to check your algorithm works - as it should. - """ - - def test_evaluation(self): - response, answer, params = "Hello, World", "Hello, World", Params() - - result = evaluation_function(response, answer, params).to_dict() - - self.assertEqual(result.get("is_correct"), True) - self.assertFalse(result.get("feedback", False)) diff --git a/evaluation_function/parser/__init__.py b/evaluation_function/parser/__init__.py new file mode 100644 index 0000000..50ef47d --- /dev/null +++ b/evaluation_function/parser/__init__.py @@ -0,0 +1,30 @@ +""" +Parser module for pseudocode analysis. 
+ +This module provides: +- Preprocessor: Normalizes pseudocode syntax variations +- Grammar: Lark grammar for pseudocode parsing +- ASTBuilder: Transforms parse tree to AST nodes (when Lark available) +- PseudocodeParser: Main parser interface +""" + +from .preprocessor import Preprocessor, PreprocessorConfig +from .grammar import PSEUDOCODE_GRAMMAR, SIMPLIFIED_GRAMMAR +from .parser import PseudocodeParser, ParseError, ParserConfig + +# ASTBuilder is only available if Lark is installed +try: + from .ast_builder import ASTBuilder +except ImportError: + ASTBuilder = None + +__all__ = [ + "Preprocessor", + "PreprocessorConfig", + "PSEUDOCODE_GRAMMAR", + "SIMPLIFIED_GRAMMAR", + "ASTBuilder", + "PseudocodeParser", + "ParseError", + "ParserConfig", +] diff --git a/evaluation_function/parser/ast_builder.py b/evaluation_function/parser/ast_builder.py new file mode 100644 index 0000000..518dcd4 --- /dev/null +++ b/evaluation_function/parser/ast_builder.py @@ -0,0 +1,507 @@ +""" +AST Builder - Transforms Lark parse tree to AST nodes. + +This module provides the Transformer class that converts Lark's parse tree +into our custom AST node types defined in schemas.ast_nodes. 
+""" + +from typing import List, Optional, Any + +# Try to import lark, but don't fail if not available +try: + from lark import Transformer, v_args, Token, Tree + LARK_AVAILABLE = True +except ImportError: + LARK_AVAILABLE = False + # Dummy classes for when lark is not available + class Transformer: + pass + def v_args(inline=False): + def decorator(cls): + return cls + return decorator + Token = None + Tree = None + +from ..schemas.ast_nodes import ( + ProgramNode, + FunctionNode, + BlockNode, + LoopNode, + ConditionalNode, + AssignmentNode, + ReturnNode, + FunctionCallNode, + ExpressionNode, + VariableNode, + LiteralNode, + BinaryOpNode, + UnaryOpNode, + ArrayAccessNode, + SourceLocation, + NodeType, + LoopType, + OperatorType, +) + + +@v_args(inline=True) +class ASTBuilder(Transformer): + """ + Transforms Lark parse tree into AST nodes. + + Uses the @v_args(inline=True) decorator to receive children as arguments + instead of a list. + """ + + def __init__(self): + super().__init__() + self.current_function: Optional[str] = None + self.function_names: List[str] = [] + + # ========================================================================= + # Program Structure + # ========================================================================= + + def start(self, *items) -> ProgramNode: + """Build the root program node.""" + functions = [] + statements = [] + + for item in items: + if item is None: + continue + if isinstance(item, FunctionNode): + functions.append(item) + elif isinstance(item, list): + statements.extend(item) + else: + statements.append(item) + + global_block = BlockNode(statements=statements) if statements else None + + return ProgramNode( + functions=functions, + global_statements=global_block + ) + + def function_def(self, keyword, name, *rest) -> FunctionNode: + """Build a function definition node.""" + params = [] + body = None + + for item in rest: + if isinstance(item, list) and all(isinstance(p, VariableNode) for p in item): + params = 
item + elif isinstance(item, BlockNode): + body = item + elif isinstance(item, list): + body = BlockNode(statements=item) + + func_name = str(name) + self.function_names.append(func_name) + + return FunctionNode( + name=func_name, + parameters=params, + body=body + ) + + def function_keyword(self, keyword) -> str: + """Extract function keyword.""" + return str(keyword) + + def param_list(self, *params) -> List[VariableNode]: + """Build parameter list.""" + return list(params) + + def param(self, name, *rest) -> VariableNode: + """Build a parameter node.""" + return VariableNode(name=str(name)) + + # ========================================================================= + # Statements + # ========================================================================= + + def statement(self, stmt) -> Any: + """Pass through statement.""" + return stmt + + def block_body(self, *statements) -> BlockNode: + """Build a block of statements.""" + stmts = [] + for s in statements: + if s is None: + continue + if isinstance(s, list): + stmts.extend(s) + else: + stmts.append(s) + return BlockNode(statements=stmts) + + def assignment(self, target, value) -> AssignmentNode: + """Build an assignment node.""" + return AssignmentNode( + target=target, + value=value, + operator=OperatorType.ASSIGN + ) + + def lvalue(self, val) -> Any: + """Extract lvalue (variable or array access).""" + return val + + # ========================================================================= + # Control Flow - Conditionals + # ========================================================================= + + def if_stmt(self, condition, *rest) -> ConditionalNode: + """Build an if statement node.""" + then_branch = None + else_branch = None + elif_branches = [] + + for item in rest: + if item is None: + continue + if isinstance(item, ConditionalNode): + elif_branches.append(item) + elif isinstance(item, BlockNode): + if then_branch is None: + then_branch = item + else: + else_branch = item + elif 
isinstance(item, list): + block = BlockNode(statements=item) + if then_branch is None: + then_branch = block + else: + else_branch = block + + return ConditionalNode( + condition=condition, + then_branch=then_branch, + else_branch=else_branch, + elif_branches=elif_branches + ) + + def elif_clause(self, condition, *rest) -> ConditionalNode: + """Build an elif clause as a ConditionalNode.""" + body = None + for item in rest: + if isinstance(item, BlockNode): + body = item + elif isinstance(item, list): + body = BlockNode(statements=item) + + return ConditionalNode( + condition=condition, + then_branch=body + ) + + def else_clause(self, *rest) -> BlockNode: + """Build else clause body.""" + for item in rest: + if isinstance(item, BlockNode): + return item + elif isinstance(item, list): + return BlockNode(statements=item) + return BlockNode(statements=[]) + + def then_clause(self, *args) -> None: + """Then clause is just syntax, return None.""" + return None + + def end_clause(self, *args) -> None: + """End clause is just syntax, return None.""" + return None + + def inline_statement(self, stmt) -> Any: + """Pass through inline statement.""" + return stmt + + # ========================================================================= + # Control Flow - Loops + # ========================================================================= + + def for_loop(self, header, body, *rest) -> LoopNode: + """Build a for loop node.""" + # Header contains the loop details + loop = header + + # Set the body + if isinstance(body, BlockNode): + loop.body = body + elif isinstance(body, list): + loop.body = BlockNode(statements=body) + + return loop + + def for_header(self, *args) -> LoopNode: + """Parse for loop header and build LoopNode.""" + # Extract components from args + iterator = None + start = None + end = None + step = None + + for arg in args: + if isinstance(arg, Token) and arg.type == 'NAME': + if iterator is None: + iterator = VariableNode(name=str(arg)) + elif 
isinstance(arg, VariableNode): + if iterator is None: + iterator = arg + elif isinstance(arg, ExpressionNode) or isinstance(arg, (int, float)): + if start is None: + start = arg if isinstance(arg, ExpressionNode) else LiteralNode(value=arg, literal_type="number") + elif end is None: + end = arg if isinstance(arg, ExpressionNode) else LiteralNode(value=arg, literal_type="number") + elif step is None: + step = arg if isinstance(arg, ExpressionNode) else LiteralNode(value=arg, literal_type="number") + + return LoopNode( + loop_type=LoopType.FOR, + iterator=iterator, + start=start, + end=end, + step=step + ) + + def step_clause(self, value) -> ExpressionNode: + """Extract step value.""" + return value + + def do_clause(self, *args) -> None: + """Do clause is just syntax, return None.""" + return None + + def while_loop(self, condition, *rest) -> LoopNode: + """Build a while loop node.""" + body = None + for item in rest: + if isinstance(item, BlockNode): + body = item + elif isinstance(item, list): + body = BlockNode(statements=item) + + return LoopNode( + loop_type=LoopType.WHILE, + condition=condition, + body=body + ) + + def repeat_loop(self, *rest) -> LoopNode: + """Build a repeat-until loop node.""" + body = None + condition = None + + for item in rest: + if isinstance(item, BlockNode): + body = item + elif isinstance(item, list): + body = BlockNode(statements=item) + elif isinstance(item, ExpressionNode): + condition = item + + return LoopNode( + loop_type=LoopType.REPEAT_UNTIL, + condition=condition, + body=body + ) + + def foreach_loop(self, *args) -> LoopNode: + """Build a for-each loop node.""" + iterator = None + collection = None + body = None + + for arg in args: + if isinstance(arg, Token) and arg.type == 'NAME': + if iterator is None: + iterator = VariableNode(name=str(arg)) + elif collection is None: + collection = VariableNode(name=str(arg)) + elif isinstance(arg, VariableNode): + if iterator is None: + iterator = arg + elif collection is None: + 
collection = arg + elif isinstance(arg, ExpressionNode): + collection = arg + elif isinstance(arg, BlockNode): + body = arg + elif isinstance(arg, list): + body = BlockNode(statements=arg) + + return LoopNode( + loop_type=LoopType.FOR_EACH, + iterator=iterator, + collection=collection, + body=body + ) + + # ========================================================================= + # Other Statements + # ========================================================================= + + def return_stmt(self, value=None) -> ReturnNode: + """Build a return statement node.""" + return ReturnNode(value=value) + + def function_call_stmt(self, call) -> FunctionCallNode: + """Function call as statement.""" + return call + + def print_stmt(self, *args) -> FunctionCallNode: + """Build a print statement as function call.""" + arguments = [a for a in args if isinstance(a, ExpressionNode)] + return FunctionCallNode( + function_name="print", + arguments=arguments + ) + + def swap_stmt(self, *args) -> FunctionCallNode: + """Build a swap statement as function call.""" + arguments = [a for a in args if isinstance(a, ExpressionNode)] + return FunctionCallNode( + function_name="swap", + arguments=arguments + ) + + def block_stmt(self, body) -> BlockNode: + """Build explicit block statement.""" + return body if isinstance(body, BlockNode) else BlockNode(statements=[body]) + + # ========================================================================= + # Expressions - Binary Operations + # ========================================================================= + + def or_op(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.OR, left=left, right=right) + + def and_op(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.AND, left=left, right=right) + + def not_op(self, operand) -> UnaryOpNode: + return UnaryOpNode(operator=OperatorType.NOT, operand=operand) + + def eq(self, left, right) -> BinaryOpNode: + return 
BinaryOpNode(operator=OperatorType.EQUAL, left=left, right=right) + + def ne(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.NOT_EQUAL, left=left, right=right) + + def lt(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.LESS_THAN, left=left, right=right) + + def le(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.LESS_EQUAL, left=left, right=right) + + def gt(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.GREATER_THAN, left=left, right=right) + + def ge(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.GREATER_EQUAL, left=left, right=right) + + def add(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.ADD, left=left, right=right) + + def sub(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.SUBTRACT, left=left, right=right) + + def mul(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.MULTIPLY, left=left, right=right) + + def div(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.DIVIDE, left=left, right=right) + + def floordiv(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.FLOOR_DIVIDE, left=left, right=right) + + def mod(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.MODULO, left=left, right=right) + + def pow(self, left, right) -> BinaryOpNode: + return BinaryOpNode(operator=OperatorType.POWER, left=left, right=right) + + def neg(self, operand) -> UnaryOpNode: + return UnaryOpNode(operator=OperatorType.SUBTRACT, operand=operand) + + def pos(self, operand) -> ExpressionNode: + return operand # Unary + doesn't change anything + + # ========================================================================= + # Expressions - Atoms + # ========================================================================= + + def number(self, 
token) -> LiteralNode: + """Build a number literal.""" + value = float(token) if '.' in str(token) else int(token) + return LiteralNode( + value=value, + literal_type="float" if isinstance(value, float) else "int" + ) + + def string(self, token) -> LiteralNode: + """Build a string literal.""" + # Remove quotes + value = str(token)[1:-1] + return LiteralNode(value=value, literal_type="string") + + def true(self) -> LiteralNode: + return LiteralNode(value=True, literal_type="bool") + + def false(self) -> LiteralNode: + return LiteralNode(value=False, literal_type="bool") + + def null(self) -> LiteralNode: + return LiteralNode(value=None, literal_type="null") + + def variable(self, token) -> VariableNode: + """Build a variable reference.""" + return VariableNode(name=str(token)) + + def array_access(self, array, index) -> ArrayAccessNode: + """Build an array access node.""" + if isinstance(array, Token): + array = VariableNode(name=str(array)) + return ArrayAccessNode(array=array, index=index) + + def function_call(self, name, *args) -> FunctionCallNode: + """Build a function call node.""" + arguments = [] + for arg in args: + if isinstance(arg, list): + arguments.extend(arg) + elif isinstance(arg, ExpressionNode): + arguments.append(arg) + + func_name = str(name) + is_recursive = func_name in self.function_names + + return FunctionCallNode( + function_name=func_name, + arguments=arguments, + is_recursive=is_recursive + ) + + def arg_list(self, *args) -> List[ExpressionNode]: + """Build argument list.""" + return [a for a in args if isinstance(a, ExpressionNode)] + + # ========================================================================= + # Utility + # ========================================================================= + + def NAME(self, token) -> str: + """Extract name from token.""" + return str(token) + + def NUMBER(self, token) -> LiteralNode: + """Convert NUMBER token.""" + return self.number(token) + + def STRING(self, token) -> LiteralNode: + 
"""Convert STRING token.""" + return self.string(token) diff --git a/evaluation_function/parser/grammar.py b/evaluation_function/parser/grammar.py new file mode 100644 index 0000000..bec1e9b --- /dev/null +++ b/evaluation_function/parser/grammar.py @@ -0,0 +1,139 @@ +""" +Lark grammar definition for pseudocode parsing. + +This grammar is designed to be flexible and handle various pseudocode styles. +It uses a simplified approach to avoid LALR conflicts. +""" + +# Simplified Lark grammar for pseudocode +# Focuses on structure detection rather than full semantic parsing +PSEUDOCODE_GRAMMAR = r''' +start: (statement | function_def | _NL)* + +// Function definition +function_def: FUNC_KW NAME "(" [params] ")" block +FUNC_KW: "function"i | "algorithm"i | "procedure"i | "def"i + +params: NAME ("," NAME)* + +// Block (indentation, end-delimited, or curly braces) +block: _NL _INDENT statement+ _DEDENT + | _NL (statement _NL?)* END_KW _NL? + | "{" _NL? (statement _NL?)* "}" + +END_KW: "end"i NAME? | "endif"i | "endfor"i | "endwhile"i | "done"i + +// Statements +statement: for_stmt + | while_stmt + | if_stmt + | repeat_stmt + | return_stmt + | call_stmt + | assignment + | expr + +// Call statement (standalone function call with CALL keyword) +call_stmt: "call"i NAME "(" [args] ")" + +// For loop +for_stmt: "for"i NAME "=" expr "to"i expr ("step"i expr)? block + | "for"i NAME "=" expr "downto"i expr ("step"i expr)? block + | "for"i "each"i? NAME "in"i expr block + +// While loop +while_stmt: "while"i expr block + +// Repeat until +repeat_stmt: "repeat"i block "until"i expr + +// If statement +if_stmt: "if"i expr "then"i? block ("elif"i expr "then"i? block)* ("else"i block)? + +// Return statement +return_stmt: "return"i expr? 
+ +// Assignment +assignment: NAME "=" expr + | NAME "[" expr "]" "=" expr + +// Expressions +?expr: or_expr + +?or_expr: and_expr (("or"i | "||") and_expr)* + +?and_expr: not_expr (("and"i | "&&") not_expr)* + +?not_expr: "not"i not_expr -> not_op + | "!" not_expr -> not_op + | comparison + +?comparison: arith (COMP_OP arith)* +COMP_OP: "==" | "!=" | "<=" | ">=" | "<" | ">" | "=" + +?arith: term (("+"|"-") term)* + +?term: factor (("*"|"/"|"//"|"%") factor)* + +?factor: power ("^" power)* + | "-" factor -> neg + | "+" factor + +?power: atom + +?atom: NUMBER + | STRING + | "true"i -> true + | "false"i -> false + | "call"i NAME "(" [args] ")" -> func_call + | NAME "(" [args] ")" -> func_call + | NAME "[" expr "]" -> array_access + | NAME -> var + | "(" expr ")" + +args: expr ("," expr)* + +// Terminals +NAME: /[a-zA-Z_][a-zA-Z0-9_]*/ +NUMBER: /\d+(\.\d+)?/ +STRING: /"[^"]*"/ | /'[^']*'/ + +// Whitespace +_NL: /(\r?\n[\t ]*)+/ +_INDENT: "" +_DEDENT: "" + +// Comments +COMMENT: "//" /[^\n]*/ | "#" /[^\n]*/ + +%ignore COMMENT +%ignore /[\t \f]+/ +''' + + +# Simplified grammar for fallback - focuses on structure detection only +SIMPLIFIED_GRAMMAR = r''' +start: line* + +line: _NL + | loop_line + | conditional_line + | function_line + | return_line + | other_line + +loop_line: LOOP_KEYWORD /[^\n]*/ +conditional_line: COND_KEYWORD /[^\n]*/ +function_line: FUNC_KEYWORD /[^\n]*/ +return_line: RETURN_KEYWORD /[^\n]*/ +other_line: /[^\n]+/ + +LOOP_KEYWORD: /\b(for|while|repeat|do|loop)\b/i +COND_KEYWORD: /\b(if|else|elif|then)\b/i +FUNC_KEYWORD: /\b(function|algorithm|procedure|def)\b/i +RETURN_KEYWORD: /\b(return)\b/i + +_NL: /\r?\n/ + +%ignore /[\t \f]+/ +''' diff --git a/evaluation_function/parser/parser.py b/evaluation_function/parser/parser.py new file mode 100644 index 0000000..c9ed980 --- /dev/null +++ b/evaluation_function/parser/parser.py @@ -0,0 +1,569 @@ +""" +Main Parser module for pseudocode. 
+ +This module provides the PseudocodeParser class that combines: +- Preprocessing (syntax normalization) +- Lark parsing (grammar-based parsing) +- Fallback parsing (pattern-based for when grammar fails) +""" + +from typing import Tuple, List, Optional, Any +from dataclasses import dataclass +import re + +from .preprocessor import Preprocessor, PreprocessorConfig +from .grammar import PSEUDOCODE_GRAMMAR, SIMPLIFIED_GRAMMAR + +from ..schemas.ast_nodes import ( + ProgramNode, + FunctionNode, + BlockNode, + LoopNode, + ConditionalNode, + VariableNode, + LiteralNode, + LoopType, +) +from ..schemas.output_schema import ParseResult + + +class ParseError(Exception): + """Exception raised when parsing fails.""" + + def __init__(self, message: str, line: int = 0, column: int = 0, + context: Optional[str] = None): + self.message = message + self.line = line + self.column = column + self.context = context + super().__init__(self._format_message()) + + def _format_message(self) -> str: + msg = self.message + if self.line: + msg += f" at line {self.line}" + if self.column: + msg += f", column {self.column}" + if self.context: + msg += f"\n Context: {self.context}" + return msg + + +@dataclass +class ParserConfig: + """Configuration for the parser.""" + use_indentation: bool = True + strict_mode: bool = False + max_errors: int = 10 + timeout: float = 5.0 + + +class PseudocodeParser: + """ + Main parser for pseudocode. + + Uses pattern-based fallback parsing for robustness since full grammar + parsing of arbitrary pseudocode is challenging. 
+ + Usage: + parser = PseudocodeParser() + result = parser.parse("FOR i = 1 TO n DO\\n print(i)") + if result.success: + ast = result.ast + """ + + def __init__(self, config: Optional[ParserConfig] = None, + preprocessor_config: Optional[PreprocessorConfig] = None): + self.config = config or ParserConfig() + self.preprocessor = Preprocessor(preprocessor_config) + self._lark_available = False + self._parser = None + + # Try to initialize Lark parser, but don't fail if it doesn't work + self._try_init_lark() + + def _try_init_lark(self): + """Try to initialize Lark parser, gracefully handle failure.""" + try: + from lark import Lark + from lark.indenter import Indenter + + class PseudocodeIndenter(Indenter): + NL_type = '_NL' + OPEN_PAREN_types = [] + CLOSE_PAREN_types = [] + INDENT_type = '_INDENT' + DEDENT_type = '_DEDENT' + tab_len = 4 + + self._parser = Lark( + PSEUDOCODE_GRAMMAR, + parser='lalr', + postlex=PseudocodeIndenter() if self.config.use_indentation else None, + propagate_positions=True, + maybe_placeholders=True, + ) + self._lark_available = True + except Exception as e: + # Lark parsing not available, will use fallback + self._lark_available = False + self._parser = None + + def parse(self, code: str) -> ParseResult: + """ + Parse pseudocode and return a ParseResult. + + Args: + code: The pseudocode to parse + + Returns: + ParseResult with success status, AST (if successful), + errors, and warnings. 
+ """ + errors: List[str] = [] + warnings: List[str] = [] + + # Step 1: Preprocess + try: + normalized_code, preprocess_warnings = self.preprocessor.preprocess(code) + warnings.extend(preprocess_warnings) + except Exception as e: + errors.append(f"Preprocessing failed: {str(e)}") + normalized_code = code + + # Step 2: Try Lark parsing if available and not in strict mode that requires it + ast = None + if self._lark_available and self._parser: + try: + tree = self._parser.parse(normalized_code) + # For now, we don't fully transform - just indicate success + ast = self._build_ast_from_tree(tree, normalized_code) + except Exception as e: + if self.config.strict_mode: + errors.append(f"Parse error: {str(e)}") + else: + warnings.append(f"Full parsing failed, using fallback: {str(e)[:100]}") + + # Step 3: Use fallback pattern-based parsing + if ast is None and not self.config.strict_mode: + try: + ast = self._parse_fallback(normalized_code) + except Exception as e: + errors.append(f"Fallback parsing failed: {str(e)}") + + return ParseResult( + success=ast is not None, + ast=ast, + errors=errors, + warnings=warnings, + normalized_code=normalized_code + ) + + def _build_ast_from_tree(self, tree, code: str) -> Optional[ProgramNode]: + """Build AST from Lark parse tree.""" + # Simplified AST building - mainly for structure detection + try: + functions = [] + statements = [] + + for child in tree.children: + if child is None: + continue + if hasattr(child, 'data'): + if child.data == 'function_def': + func = self._extract_function(child) + if func: + functions.append(func) + elif child.data in ('for_stmt', 'while_stmt', 'if_stmt', 'repeat_stmt'): + stmt = self._extract_statement(child) + if stmt: + statements.append(stmt) + + global_block = BlockNode(statements=statements) if statements else None + return ProgramNode(functions=functions, global_statements=global_block) + except Exception: + return None + + def _extract_function(self, node) -> Optional[FunctionNode]: + 
"""Extract function from parse tree node.""" + try: + name = "" + for child in node.children: + if hasattr(child, 'type') and child.type == 'NAME': + name = str(child) + break + return FunctionNode(name=name, parameters=[], body=None) + except Exception: + return None + + def _extract_statement(self, node) -> Optional[Any]: + """Extract statement from parse tree node.""" + try: + if node.data == 'for_stmt': + return LoopNode(loop_type=LoopType.FOR, body=BlockNode(statements=[])) + elif node.data == 'while_stmt': + return LoopNode(loop_type=LoopType.WHILE, body=BlockNode(statements=[])) + elif node.data == 'if_stmt': + return ConditionalNode(then_branch=BlockNode(statements=[])) + except Exception: + pass + return None + + def _parse_fallback(self, code: str) -> ProgramNode: + """ + Fallback parsing using pattern detection. + + This method uses regex patterns to detect loops, conditionals, + and functions when the full grammar fails. + """ + functions: List[FunctionNode] = [] + statements = [] + + lines = code.split('\n') + indent_unit = self.preprocessor.detect_indentation_style(code) + + i = 0 + while i < len(lines): + line = lines[i] + stripped = line.strip() + + if not stripped: + i += 1 + continue + + indent_level = self.preprocessor.get_indent_level(line, indent_unit) + + # Detect function definitions (with optional { at end) + func_match = re.match( + r'^(function|algorithm|procedure|def)\s+(\w+)\s*\([^)]*\)\s*\{?', + stripped, re.IGNORECASE + ) + if func_match: + func_name = func_match.group(2) + body_lines, end_idx = self._collect_block(lines, i + 1, indent_level, indent_unit) + func = FunctionNode( + name=func_name, + parameters=[], + body=self._parse_block_fallback(body_lines, indent_unit) + ) + functions.append(func) + i = end_idx + continue + + # Detect loops + loop_node = self._detect_loop(stripped, lines, i, indent_level, indent_unit) + if loop_node: + statements.append(loop_node[0]) + i = loop_node[1] + continue + + # Detect conditionals + cond_node 
= self._detect_conditional(stripped, lines, i, indent_level, indent_unit) + if cond_node: + statements.append(cond_node[0]) + i = cond_node[1] + continue + + # Other statement - skip + i += 1 + + global_block = BlockNode(statements=statements) if statements else None + return ProgramNode(functions=functions, global_statements=global_block) + + def _collect_block(self, lines: List[str], start_idx: int, + base_indent: int, indent_unit: int) -> Tuple[List[str], int]: + """Collect lines belonging to a block. + + Supports three block styles: + 1. Indentation-based (Python-like) + 2. END keyword-based (END IF, END FOR, etc.) + 3. Curly brace-based ({ ... }) + """ + block_lines = [] + i = start_idx + brace_count = 0 + + # Check if the block starts with an opening brace + if i < len(lines): + first_line = lines[i].strip() + if first_line == '{' or first_line.endswith('{'): + brace_count = 1 + # If just '{', skip it; if 'DO {', include content after + if first_line == '{': + i += 1 + else: + # Remove the trailing brace from this line + block_lines.append(first_line[:-1].strip()) + i += 1 + + while i < len(lines): + line = lines[i] + stripped = line.strip() + + if not stripped: + block_lines.append(line) + i += 1 + continue + + # Handle curly brace blocks + if brace_count > 0: + brace_count += stripped.count('{') - stripped.count('}') + if brace_count <= 0: + # Remove trailing } if present + if stripped == '}': + i += 1 + else: + block_lines.append(stripped.rstrip('}').strip()) + i += 1 + break + block_lines.append(line) + i += 1 + continue + + current_indent = self.preprocessor.get_indent_level(line, indent_unit) + + # End markers (keyword-based) + if re.match(r'^(end\b|endif\b|endfor\b|endwhile\b|done\b|\})', stripped, re.IGNORECASE): + i += 1 + break + + if current_indent <= base_indent and i > start_idx: + break + + block_lines.append(line) + i += 1 + + return block_lines, i + + def _parse_block_fallback(self, lines: List[str], indent_unit: int) -> BlockNode: + """Parse 
a block of lines into a BlockNode.""" + statements = [] + + i = 0 + while i < len(lines): + line = lines[i] + stripped = line.strip() + + if not stripped: + i += 1 + continue + + indent_level = self.preprocessor.get_indent_level(line, indent_unit) + + loop_node = self._detect_loop(stripped, lines, i, indent_level, indent_unit) + if loop_node: + statements.append(loop_node[0]) + i = loop_node[1] + continue + + cond_node = self._detect_conditional(stripped, lines, i, indent_level, indent_unit) + if cond_node: + statements.append(cond_node[0]) + i = cond_node[1] + continue + + i += 1 + + return BlockNode(statements=statements) + + def _detect_loop(self, line: str, lines: List[str], idx: int, + indent_level: int, indent_unit: int) -> Optional[Tuple[LoopNode, int]]: + """Detect and parse a loop.""" + + # FOR loop: for i = 1 to n (with optional DO or { at end) + for_match = re.match( + r'for\s+(\w+)\s*[=:]\s*(\w+)\s+to\s+(\w+)(?:\s+(?:do|step\s+\w+))?\s*\{?', + line, re.IGNORECASE + ) + if for_match: + iterator = VariableNode(name=for_match.group(1)) + start = self._parse_simple_expr(for_match.group(2)) + end = self._parse_simple_expr(for_match.group(3)) + + body_lines, next_idx = self._collect_block(lines, idx + 1, indent_level, indent_unit) + body = self._parse_block_fallback(body_lines, indent_unit) + estimated = self._estimate_iterations(for_match.group(2), for_match.group(3)) + + return LoopNode( + loop_type=LoopType.FOR, + iterator=iterator, + start=start, + end=end, + body=body, + estimated_iterations=estimated, + nesting_level=indent_level + ), next_idx + + # WHILE loop (with optional DO or { at end) + while_match = re.match(r'while\s+(.+?)(?:\s+do)?\s*\{?$', line, re.IGNORECASE) + if while_match: + body_lines, next_idx = self._collect_block(lines, idx + 1, indent_level, indent_unit) + body = self._parse_block_fallback(body_lines, indent_unit) + + return LoopNode( + loop_type=LoopType.WHILE, + body=body, + estimated_iterations="unknown", + 
nesting_level=indent_level + ), next_idx + + # FOR EACH loop (with optional { at end) + foreach_match = re.match( + r'for\s+(?:each\s+)?(\w+)\s+in\s+(\w+)(?:\s+do)?\s*\{?', + line, re.IGNORECASE + ) + if foreach_match: + iterator = VariableNode(name=foreach_match.group(1)) + collection = VariableNode(name=foreach_match.group(2)) + + body_lines, next_idx = self._collect_block(lines, idx + 1, indent_level, indent_unit) + body = self._parse_block_fallback(body_lines, indent_unit) + + return LoopNode( + loop_type=LoopType.FOR_EACH, + iterator=iterator, + collection=collection, + body=body, + estimated_iterations="n", + nesting_level=indent_level + ), next_idx + + # REPEAT loop + repeat_match = re.match(r'repeat\b', line, re.IGNORECASE) + if repeat_match: + body_lines, next_idx = self._collect_block(lines, idx + 1, indent_level, indent_unit) + body = self._parse_block_fallback(body_lines, indent_unit) + + return LoopNode( + loop_type=LoopType.REPEAT_UNTIL, + body=body, + estimated_iterations="unknown", + nesting_level=indent_level + ), next_idx + + return None + + def _detect_conditional(self, line: str, lines: List[str], idx: int, + indent_level: int, indent_unit: int) -> Optional[Tuple[ConditionalNode, int]]: + """Detect and parse a conditional.""" + if_match = re.match(r'if\s+(.+?)(?:\s+then)?\s*\{?$', line, re.IGNORECASE) + if if_match: + body_lines, next_idx = self._collect_block(lines, idx + 1, indent_level, indent_unit) + then_branch = self._parse_block_fallback(body_lines, indent_unit) + + else_branch = None + if next_idx < len(lines): + else_line = lines[next_idx].strip().lower() + if else_line.startswith('else'): + else_lines, next_idx = self._collect_block(lines, next_idx + 1, indent_level, indent_unit) + else_branch = self._parse_block_fallback(else_lines, indent_unit) + + return ConditionalNode( + then_branch=then_branch, + else_branch=else_branch + ), next_idx + + return None + + def _parse_simple_expr(self, expr_str: str) -> Any: + """Parse a simple 
expression.""" + expr_str = expr_str.strip() + try: + if '.' in expr_str: + return LiteralNode(value=float(expr_str), literal_type="float") + return LiteralNode(value=int(expr_str), literal_type="int") + except ValueError: + return VariableNode(name=expr_str) + + def _estimate_iterations(self, start: str, end: str) -> str: + """Estimate number of iterations.""" + start = start.strip().lower() + end = end.strip().lower() + + if start in ('0', '1') and end in ('n', 'len', 'length', 'size'): + return "n" + if start in ('0', '1') and 'n-' in end: + return "n" + if start in ('0', '1') and end.startswith('n/'): + return "n/2" + if start in ('0', '1') and 'log' in end: + return "log(n)" + + try: + s = int(start) + e = int(end) + return str(e - s + 1) + except ValueError: + pass + + if end.isalpha(): + return end + return "n" + + def detect_structure(self, code: str) -> dict: + """ + Detect high-level structure without full parsing. + """ + code_lower = code.lower() + lines = code.split('\n') + + # Count loop keywords, excluding END keywords like "END FOR", "ENDFOR", etc. 
+ # First count raw matches, then subtract END matches + loop_count = 0 + for line in lines: + line_lower = line.strip().lower() + # Skip end keywords + if line_lower.startswith('end') or line_lower.startswith('done'): + continue + # Count loop starts + if re.match(r'^for\b', line_lower): + loop_count += 1 + elif re.match(r'^while\b', line_lower): + loop_count += 1 + elif re.match(r'^repeat\b', line_lower): + loop_count += 1 + elif re.match(r'^loop\b', line_lower): + loop_count += 1 + + max_nesting = 0 + indent_unit = self.preprocessor.detect_indentation_style(code) + + # Track loop nesting by counting active loops based on keywords + current_nesting = 0 + for line in lines: + stripped = line.strip().lower() + if not stripped: + continue + + # Check if this line starts a loop (also check for opening brace) + if any(stripped.startswith(kw) for kw in ['for ', 'for(', 'while ', 'while(', 'repeat']): + current_nesting += 1 + max_nesting = max(max_nesting, current_nesting) + # If the line ends with {, the brace is part of this loop start + # (already counted above) + + # Check if this line ends a loop block (END keywords or closing brace) + if (stripped.startswith('end for') or stripped.startswith('endfor') or + stripped.startswith('end while') or stripped.startswith('endwhile') or + stripped == 'done' or stripped.startswith('until ') or + stripped == '}'): + current_nesting = max(0, current_nesting - 1) + + has_recursion = False + func_match = re.search(r'(function|algorithm|def)\s+(\w+)', code_lower) + if func_match: + func_name = func_match.group(2) + call_pattern = rf'\b{func_name}\s*\(' + calls = re.findall(call_pattern, code_lower) + has_recursion = len(calls) > 1 + + has_conditionals = bool(re.search(r'\bif\b', code_lower)) + + return { + 'has_loops': loop_count > 0, + 'has_nested_loops': max_nesting > 1, # Only true if loops are actually nested + 'has_recursion': has_recursion, + 'loop_count': loop_count, + 'max_nesting': max_nesting, + 'has_conditionals': 
import re
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class PreprocessorConfig:
    """Configuration for the preprocessor."""
    normalize_case: bool = True
    normalize_operators: bool = True
    normalize_whitespace: bool = True
    fix_common_typos: bool = True
    preserve_strings: bool = True
    tab_size: int = 4


class Preprocessor:
    """
    Normalizes pseudocode to a standard format for parsing.

    Handles variations in:
    - Loop keywords: FOR, for, For, LOOP, loop
    - Conditionals: IF, if, THEN, then, ELSE, else
    - Assignment: =, :=, ←, <-
    - Comparisons: ==, =, ≤, <=, ≥, >=, ≠, !=, <>
    - Logical operators: AND, and, &&, &, OR, or, ||, |, NOT, not
      (a bare ``!`` is left untouched)
    - Function definitions: FUNCTION, function, ALGORITHM, algorithm, PROCEDURE
    - Return: RETURN, return, RETURNS
    - Ranges: TO, to, DOWNTO, downto

    NOTE: keyword lookup is done one identifier at a time, so the multi-word
    variations listed in ``KEYWORD_MAPPINGS`` ("END IF", "ELSE IF",
    "DOWN TO", ...) are never matched as a unit; each word is normalized
    individually ("END IF" becomes "end if").
    """

    # Keyword mappings (normalized form → variations)
    # Note: algorithm is kept separate from function
    KEYWORD_MAPPINGS = {
        # Loop keywords
        "for": ["FOR", "For", "LOOP", "loop", "Loop"],
        "while": ["WHILE", "While", "WHILST", "whilst"],
        "do": ["DO", "Do"],
        "end": ["END", "End", "ENDFOR", "endfor", "ENDWHILE", "endwhile",
                "ENDIF", "endif", "END IF", "end if", "END FOR", "end for",
                "END WHILE", "end while", "DONE", "done"],
        "repeat": ["REPEAT", "Repeat"],
        "until": ["UNTIL", "Until"],
        "to": ["TO", "To"],
        "downto": ["DOWNTO", "Downto", "DOWN TO", "down to"],
        "step": ["STEP", "Step", "BY", "by"],
        "in": ["IN", "In"],
        "each": ["EACH", "Each"],

        # Conditional keywords
        "if": ["IF", "If"],
        "then": ["THEN", "Then"],
        "else": ["ELSE", "Else"],
        "elif": ["ELIF", "Elif", "ELSEIF", "elseif", "ELSE IF", "else if", "elsif", "ELSIF"],

        # Logical operators (word forms only - && and || handled separately)
        "and": ["AND", "And"],
        "or": ["OR", "Or"],
        "not": ["NOT", "Not"],

        # Function keywords - algorithm and function are separate!
        "function": ["FUNCTION", "Function", "FUNC", "func"],
        "algorithm": ["ALGORITHM", "Algorithm"],
        "procedure": ["PROCEDURE", "Procedure"],
        "def": ["DEF"],
        "return": ["RETURN", "Return", "RETURNS", "returns"],

        # Boolean literals
        "true": ["TRUE", "True"],
        "false": ["FALSE", "False"],

        # Other keywords
        "null": ["NULL", "Null", "NIL", "nil", "NONE", "None", "none"],
        "print": ["PRINT", "Print", "OUTPUT", "output", "WRITE", "write", "DISPLAY", "display"],
        "input": ["INPUT", "Input", "READ", "read"],
        "swap": ["SWAP", "Swap"],
        "call": ["CALL", "Call"],
    }

    # Operator mappings (symbols that aren't word characters).
    # Order matters: the two-character forms must be replaced before the
    # single-character ones.
    OPERATOR_MAPPINGS = {
        "&&": "and",
        "||": "or",
        "&": "and",
        "|": "or",
    }

    # Assignment operator mappings
    ASSIGNMENT_OPERATORS = {
        ":=": "=",
        "←": "=",
        "<-": "=",
        "⟵": "=",
    }

    # Comparison operator mappings
    COMPARISON_OPERATORS = {
        "≤": "<=",
        "≥": ">=",
        "≠": "!=",
        "<>": "!=",
    }

    # Common typos
    TYPO_FIXES = {
        "whlie": "while",
        "wihle": "while",
        "fro": "for",
        "fo": "for",
        "eles": "else",
        "esle": "else",
        "retrun": "return",
        "reutrn": "return",
        "fucntion": "function",
        "funtion": "function",
        "funciton": "function",
        "algoritm": "algorithm",
        "algortihm": "algorithm",
        "pritn": "print",
        "pirnt": "print",
    }

    # A lone "=" between a condition keyword and then/do is a comparison.
    # Word boundaries keep the keywords from matching inside identifiers
    # (e.g. "notif", "done"); the elif spellings are listed explicitly so
    # they keep matching now that a leading boundary is required.
    _CONDITION_EQUALS = re.compile(
        r'\b(elseif|elsif|elif|if|while|until)\s+([^=\n]+)\s+=\s+([^=\n]+)\s+(then|do)\b',
        flags=re.IGNORECASE,
    )

    # Placeholder emitted by _extract_strings.
    _PLACEHOLDER = re.compile(r'__STRING_(\d+)__')

    def __init__(self, config: Optional[PreprocessorConfig] = None):
        self.config = config or PreprocessorConfig()
        self._build_patterns()

    def _build_patterns(self):
        """Build the lookup tables used for keyword and typo normalization."""
        # Lower-cased variation → normalized keyword.
        self._keyword_map = {}
        for normalized, variations in self.KEYWORD_MAPPINGS.items():
            for var in variations:
                self._keyword_map[var.lower()] = normalized
            self._keyword_map[normalized] = normalized

        # Always built (it is tiny) so _fix_typos works even when the
        # fix_common_typos flag is switched on after construction.
        self._typo_map = {k.lower(): v for k, v in self.TYPO_FIXES.items()}

    def preprocess(self, code: str) -> Tuple[str, List[str]]:
        """
        Preprocess pseudocode and return normalized version.

        Each step is gated by the matching ``PreprocessorConfig`` flag.
        String literals are swapped for placeholders first and restored last
        so that no step rewrites their contents.

        Args:
            code: Raw pseudocode string

        Returns:
            Tuple of (normalized_code, warnings)
        """
        warnings = []

        # Preserve string literals
        strings = []
        if self.config.preserve_strings:
            code, strings = self._extract_strings(code)

        # Normalize line endings
        code = code.replace('\r\n', '\n').replace('\r', '\n')

        # Fix common typos
        if self.config.fix_common_typos:
            code, typo_warnings = self._fix_typos(code)
            warnings.extend(typo_warnings)

        # Normalize operators (including && and ||)
        if self.config.normalize_operators:
            code = self._normalize_operators(code)

        # Normalize keywords
        if self.config.normalize_case:
            code = self._normalize_keywords(code)

        # Normalize whitespace
        if self.config.normalize_whitespace:
            code = self._normalize_whitespace(code)

        # Restore string literals
        if self.config.preserve_strings:
            code = self._restore_strings(code, strings)

        return code, warnings

    def _extract_strings(self, code: str) -> Tuple[str, List[str]]:
        """Replace string literals with ``__STRING_<i>__`` placeholders.

        Returns:
            Tuple of (code with placeholders, literals in order of appearance)
        """
        strings = []

        def replace_string(match):
            strings.append(match.group(0))
            return f"__STRING_{len(strings) - 1}__"

        # Match both single and double quoted strings
        pattern = r'"[^"\\]*(?:\\.[^"\\]*)*"|\'[^\'\\]*(?:\\.[^\'\\]*)*\''
        code = re.sub(pattern, replace_string, code)

        return code, strings

    def _restore_strings(self, code: str, strings: List[str]) -> str:
        """Restore string literals from placeholders.

        Done in a single pass so that restored text is never re-scanned: a
        literal whose content looks like a placeholder stays intact.
        Placeholders with no matching entry in ``strings`` are left as-is.
        """
        def restore(match):
            digits = match.group(1)
            index = int(digits)
            # str(index) == digits rejects zero-padded look-alikes ("00").
            if index < len(strings) and str(index) == digits:
                return strings[index]
            return match.group(0)

        return self._PLACEHOLDER.sub(restore, code)

    def _fix_typos(self, code: str) -> Tuple[str, List[str]]:
        """Fix common typos in keywords.

        Returns:
            Tuple of (corrected code, one warning message per replacement)
        """
        warnings = []

        def fix_word(match):
            word = match.group(0)
            lower = word.lower()
            if lower in self._typo_map:
                fixed = self._typo_map[lower]
                warnings.append(f"Fixed typo: '{word}' → '{fixed}'")
                return fixed
            return word

        # Match whole alphabetic words only
        code = re.sub(r'\b[a-zA-Z]+\b', fix_word, code)

        return code, warnings

    def _normalize_operators(self, code: str) -> str:
        """Normalize assignment, comparison, and logical operators."""
        # Symbolic logical operators become space-padded word forms; the
        # extra spaces are collapsed later by whitespace normalization.
        for old, new in self.OPERATOR_MAPPINGS.items():
            code = code.replace(old, f' {new} ')

        # Assignment operators (must be done before comparison)
        for old, new in self.ASSIGNMENT_OPERATORS.items():
            code = code.replace(old, new)

        # Handle single = that should be == (in comparisons).
        # This needs context; only the obvious case is handled:
        # "= " after if/elif/while/until and before then/do becomes "==".
        code = self._CONDITION_EQUALS.sub(r'\1 \2 == \3 \4', code)

        # Unicode comparison operators
        for old, new in self.COMPARISON_OPERATORS.items():
            code = code.replace(old, new)

        return code

    def _normalize_keywords(self, code: str) -> str:
        """Normalize keyword case to lowercase standard form."""

        def normalize_word(match):
            word = match.group(0)
            return self._keyword_map.get(word.lower(), word)

        # Match whole identifiers only
        return re.sub(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', normalize_word, code)

    def _normalize_whitespace(self, code: str) -> str:
        """Normalize whitespace and indentation.

        Tabs become ``tab_size`` spaces, trailing whitespace is removed,
        interior runs of whitespace collapse to one space (leading indent is
        kept), and three or more consecutive newlines collapse to two.
        """
        normalized_lines = []

        for line in code.split('\n'):
            # Convert tabs to spaces
            line = line.replace('\t', ' ' * self.config.tab_size)

            # Remove trailing whitespace
            line = line.rstrip()

            # Normalize multiple spaces (except leading)
            if line:
                leading = len(line) - len(line.lstrip())
                content = ' '.join(line.split())
                line = ' ' * leading + content

            normalized_lines.append(line)

        # Remove multiple blank lines
        code = '\n'.join(normalized_lines)
        code = re.sub(r'\n{3,}', '\n\n', code)

        return code.strip()

    def _leading_width(self, line: str) -> int:
        """Return the indentation width of ``line`` with tabs expanded.

        Tabs count as ``config.tab_size`` columns, matching
        ``_normalize_whitespace``.
        """
        expanded = line.replace('\t', ' ' * self.config.tab_size)
        return len(expanded) - len(expanded.lstrip())

    def detect_indentation_style(self, code: str) -> int:
        """
        Detect the indentation unit used in the code.

        Returns:
            Number of spaces per indent level (common values: 2, 4).
            Defaults to 4 when no line is indented; never less than 2.
        """
        from functools import reduce
        from math import gcd

        indents = []
        for line in code.split('\n'):
            if line.strip():  # Non-empty line
                leading = self._leading_width(line)
                if leading > 0:
                    indents.append(leading)

        if not indents:
            return 4  # Default

        # The unit is the GCD of all observed indentation widths
        indent_unit = reduce(gcd, indents)
        return max(indent_unit, 2)  # At least 2 spaces

    def get_indent_level(self, line: str, indent_unit: int = 4) -> int:
        """Get the indentation level of a line (0 for blank lines)."""
        if not line.strip():
            return 0
        return self._leading_width(line) // indent_unit
- --- - The handler function passes three arguments to preview_function(): - - - `response` which are the answers provided by the student. - - `params` which are any extra parameters that may be useful, - e.g., error tolerances. + Preview a student's pseudocode submission. - The output of this function is what is returned as the API response - and therefore must be JSON-encodable. It must also conform to the - response schema. + This function parses and analyzes the pseudocode to provide immediate + feedback on whether the code is valid and what complexity is detected. - Any standard python library may be used, as well as any package - available on pip (provided it is added to requirements.txt). + Args: + response: The student's pseudocode (string) + params: Additional parameters for configuration - The way you wish to structure you code (all in this function, or - split into many) is entirely up to you. + Returns: + Result containing a Preview with: + - latex: Formatted complexity result + - feedback: Detailed feedback about parsing and analysis """ - try: - return Result(preview=Preview(sympy=response)) - except FeedbackException as e: - return Result(preview=Preview(feedback=str(e))) + # Extract pseudocode from response + if isinstance(response, dict): + pseudocode = response.get('pseudocode', response.get('code', '')) + elif isinstance(response, str): + pseudocode = response + else: + return Result(preview=Preview( + feedback="Invalid response format. Please provide pseudocode as a string." + )) + + if not pseudocode or not pseudocode.strip(): + return Result(preview=Preview( + feedback="Please enter your pseudocode to see a preview." + )) + + # Parse the pseudocode + parser = PseudocodeParser() + parse_result = parser.parse(pseudocode) + + if not parse_result.success: + error_msg = "Failed to parse pseudocode." 
+ if parse_result.errors: + error_msg += "\n\nErrors:\n" + "\n".join(f"- {e}" for e in parse_result.errors) + if parse_result.warnings: + error_msg += "\n\nWarnings:\n" + "\n".join(f"- {w}" for w in parse_result.warnings) + return Result(preview=Preview(feedback=error_msg)) + + # Analyze complexity + analyzer = ComplexityAnalyzer() + analysis = analyzer.analyze(pseudocode, parse_result.ast) + + # Detect structure + structure = parser.detect_structure(pseudocode) + + # Generate preview content + preview_content = _generate_preview_content(analysis, structure, parse_result) + + return Result(preview=Preview( + latex=f"\\text{{Time Complexity: }} {_latex_complexity(analysis.time_complexity.value)}", + feedback=preview_content + )) + except Exception as e: - return Result(preview=Preview(feedback=str(e))) + return Result(preview=Preview( + feedback=f"An error occurred during preview: {str(e)}" + )) + + +def _generate_preview_content(analysis, structure: Dict, parse_result) -> str: + """Generate detailed preview content.""" + lines = [] + + # Header + lines.append("=" * 50) + lines.append("PSEUDOCODE ANALYSIS PREVIEW") + lines.append("=" * 50) + lines.append("") + + # Parsing status + lines.append("✓ Parsing: Successful") + if parse_result.warnings: + for w in parse_result.warnings[:3]: # Show max 3 warnings + lines.append(f" ⚠ {w}") + lines.append("") + + # Detected structure + lines.append("-" * 50) + lines.append("DETECTED STRUCTURE") + lines.append("-" * 50) + + structure_items = [] + if structure.get('has_loops'): + loop_count = structure.get('loop_count', 0) + structure_items.append(f"• Loops: {loop_count}") + if structure.get('has_nested_loops'): + max_nesting = structure.get('max_nesting', 0) + structure_items.append(f"• Nested loops: Yes (depth {max_nesting})") + if structure.get('has_recursion'): + structure_items.append("• Recursion: Yes") + if structure.get('has_conditionals'): + structure_items.append("• Conditionals: Yes") + + if structure_items: + 
lines.extend(structure_items) + else: + lines.append("• No loops or recursion detected (constant time)") + lines.append("") + + # Complexity result + lines.append("-" * 50) + lines.append("COMPLEXITY ANALYSIS") + lines.append("-" * 50) + lines.append(f"Time Complexity: {analysis.time_complexity.value}") + lines.append(f"Space Complexity: {analysis.space_complexity.value}") + lines.append("") + + # Loop details + if analysis.loops: + lines.append("-" * 50) + lines.append("LOOP DETAILS") + lines.append("-" * 50) + for i, loop in enumerate(analysis.loops, 1): + lines.append(f"{i}. {loop.get_description()}") + lines.append(f" Iterations: {loop.iterations}") + lines.append(f" Contribution: {loop.complexity.value}") + if loop.nested_loops: + for j, nested in enumerate(loop.nested_loops, 1): + lines.append(f" └─ Nested {j}: {nested.get_description()} ({nested.complexity.value})") + lines.append("") + + # Recursion details + if analysis.recursion: + rec = analysis.recursion + lines.append("-" * 50) + lines.append("RECURSION DETAILS") + lines.append("-" * 50) + lines.append(f"Function: {rec.function_name}()") + lines.append(f"Branching factor: {rec.branching_factor}") + lines.append(f"Reduction pattern: {rec.reduction_pattern}") + lines.append(f"Recurrence: {rec.recurrence}") + lines.append("") + + # Confidence + lines.append("-" * 50) + confidence_pct = int(analysis.confidence * 100) + lines.append(f"Analysis confidence: {confidence_pct}%") + + return "\n".join(lines) + + +def _latex_complexity(complexity: str) -> str: + """Convert complexity to LaTeX format.""" + replacements = { + "O(1)": "O(1)", + "O(log n)": "O(\\log n)", + "O(√n)": "O(\\sqrt{n})", + "O(n)": "O(n)", + "O(n log n)": "O(n \\log n)", + "O(n²)": "O(n^2)", + "O(n³)": "O(n^3)", + "O(n^k)": "O(n^k)", + "O(2^n)": "O(2^n)", + "O(n!)": "O(n!)", + } + return replacements.get(complexity, complexity) diff --git a/evaluation_function/preview_test.py b/evaluation_function/preview_test.py deleted file mode 100755 
index a8834a7..0000000 --- a/evaluation_function/preview_test.py +++ /dev/null @@ -1,29 +0,0 @@ -import unittest - -from .preview import Params, preview_function - -class TestPreviewFunction(unittest.TestCase): - """ - TestCase Class used to test the algorithm. - --- - Tests are used here to check that the algorithm written - is working as it should. - - It's best practice to write these tests first to get a - kind of 'specification' for how your algorithm should - work, and you should run these tests before committing - your code to AWS. - - Read the docs on how to use unittest here: - https://docs.python.org/3/library/unittest.html - - Use preview_function() to check your algorithm works - as it should. - """ - - def test_preview(self): - response, params = "A", Params() - result = preview_function(response, params) - - self.assertIn("preview", result) - self.assertIsNotNone(result["preview"]) diff --git a/evaluation_function/schemas/complexity.py b/evaluation_function/schemas/complexity.py index 92e04d2..0310aa1 100644 --- a/evaluation_function/schemas/complexity.py +++ b/evaluation_function/schemas/complexity.py @@ -164,11 +164,21 @@ def multiply(cls, a: "ComplexityClass", b: "ComplexityClass") -> "ComplexityClas (cls.LINEAR, cls.LINEAR): cls.QUADRATIC, (cls.LINEAR, cls.QUADRATIC): cls.CUBIC, (cls.QUADRATIC, cls.LINEAR): cls.CUBIC, + (cls.LINEAR, cls.CUBIC): cls.POLYNOMIAL, # O(n⁴) + (cls.CUBIC, cls.LINEAR): cls.POLYNOMIAL, # O(n⁴) (cls.LINEAR, cls.LOGARITHMIC): cls.LINEARITHMIC, (cls.LOGARITHMIC, cls.LINEAR): cls.LINEARITHMIC, (cls.LINEAR, cls.LINEARITHMIC): cls.POLYNOMIAL, # O(n² log n) ≈ polynomial (cls.LINEARITHMIC, cls.LINEAR): cls.POLYNOMIAL, (cls.QUADRATIC, cls.QUADRATIC): cls.POLYNOMIAL, # O(n⁴) + (cls.QUADRATIC, cls.CUBIC): cls.POLYNOMIAL, # O(n⁵) + (cls.CUBIC, cls.QUADRATIC): cls.POLYNOMIAL, # O(n⁵) + (cls.CUBIC, cls.CUBIC): cls.POLYNOMIAL, # O(n⁶) + (cls.LINEAR, cls.POLYNOMIAL): cls.POLYNOMIAL, # Still polynomial + (cls.POLYNOMIAL, cls.LINEAR): 
cls.POLYNOMIAL, # Still polynomial + (cls.QUADRATIC, cls.POLYNOMIAL): cls.POLYNOMIAL, + (cls.POLYNOMIAL, cls.QUADRATIC): cls.POLYNOMIAL, + (cls.POLYNOMIAL, cls.POLYNOMIAL): cls.POLYNOMIAL, } result = rules.get((a, b)) diff --git a/evaluation_function/tests/__init__.py b/evaluation_function/tests/__init__.py new file mode 100644 index 0000000..64d2823 --- /dev/null +++ b/evaluation_function/tests/__init__.py @@ -0,0 +1,10 @@ +""" +Test suite for the Algorithm Complexity Evaluation Function. + +This package contains comprehensive tests for: +- Preprocessor: Syntax normalization +- Parser: Pseudocode parsing +- AST Builder: Parse tree transformation +- Complexity Analysis: Complexity detection and calculation +- Integration: End-to-end tests +""" diff --git a/evaluation_function/tests/conftest.py b/evaluation_function/tests/conftest.py new file mode 100644 index 0000000..cd902c2 --- /dev/null +++ b/evaluation_function/tests/conftest.py @@ -0,0 +1,370 @@ +""" +Pytest configuration and fixtures for the test suite. 
+""" + +import pytest +from typing import List, Dict, Any + +from ..parser.preprocessor import Preprocessor, PreprocessorConfig +from ..parser.parser import PseudocodeParser, ParserConfig +from ..schemas.complexity import ComplexityClass + + +# ============================================================================= +# Parser Fixtures +# ============================================================================= + +@pytest.fixture +def preprocessor() -> Preprocessor: + """Create a default preprocessor instance.""" + return Preprocessor() + + +@pytest.fixture +def preprocessor_strict() -> Preprocessor: + """Create a preprocessor with strict settings.""" + config = PreprocessorConfig( + normalize_case=True, + normalize_operators=True, + normalize_whitespace=True, + fix_common_typos=False, # Strict: don't fix typos + preserve_strings=True, + ) + return Preprocessor(config) + + +@pytest.fixture +def parser() -> PseudocodeParser: + """Create a default parser instance.""" + return PseudocodeParser() + + +@pytest.fixture +def parser_strict() -> PseudocodeParser: + """Create a parser with strict mode enabled.""" + config = ParserConfig(strict_mode=True) + return PseudocodeParser(config) + + +# ============================================================================= +# Sample Pseudocode Fixtures +# ============================================================================= + +@pytest.fixture +def simple_for_loop() -> str: + """Simple FOR loop pseudocode.""" + return """FOR i = 1 TO n DO + print(i) +END FOR""" + + +@pytest.fixture +def nested_for_loops() -> str: + """Nested FOR loops pseudocode.""" + return """FOR i = 1 TO n DO + FOR j = 1 TO n DO + sum = sum + A[i][j] + END FOR +END FOR""" + + +@pytest.fixture +def triple_nested_loops() -> str: + """Triple nested loops pseudocode.""" + return """FOR i = 1 TO n DO + FOR j = 1 TO n DO + FOR k = 1 TO n DO + result = result + A[i][j][k] + END FOR + END FOR +END FOR""" + + +@pytest.fixture +def while_loop() -> str: 
+ """Simple WHILE loop pseudocode.""" + return """WHILE i < n DO + i = i + 1 + count = count + 1 +END WHILE""" + + +@pytest.fixture +def binary_search() -> str: + """Binary search algorithm (O(log n)).""" + return """FUNCTION binarySearch(A, target, low, high) + WHILE low <= high DO + mid = (low + high) / 2 + IF A[mid] == target THEN + RETURN mid + ELSE IF A[mid] < target THEN + low = mid + 1 + ELSE + high = mid - 1 + END IF + END WHILE + RETURN -1 +END FUNCTION""" + + +@pytest.fixture +def bubble_sort() -> str: + """Bubble sort algorithm (O(n²)).""" + return """FUNCTION bubbleSort(A, n) + FOR i = 1 TO n-1 DO + FOR j = 1 TO n-i DO + IF A[j] > A[j+1] THEN + swap(A[j], A[j+1]) + END IF + END FOR + END FOR +END FUNCTION""" + + +@pytest.fixture +def recursive_fibonacci() -> str: + """Recursive Fibonacci (O(2^n)).""" + return """FUNCTION fib(n) + IF n <= 1 THEN + RETURN n + END IF + RETURN fib(n-1) + fib(n-2) +END FUNCTION""" + + +@pytest.fixture +def recursive_factorial() -> str: + """Recursive factorial (O(n)).""" + return """FUNCTION factorial(n) + IF n <= 1 THEN + RETURN 1 + END IF + RETURN n * factorial(n-1) +END FUNCTION""" + + +@pytest.fixture +def merge_sort() -> str: + """Merge sort algorithm (O(n log n)).""" + return """FUNCTION mergeSort(A, left, right) + IF left < right THEN + mid = (left + right) / 2 + mergeSort(A, left, mid) + mergeSort(A, mid+1, right) + merge(A, left, mid, right) + END IF +END FUNCTION + +FUNCTION merge(A, left, mid, right) + FOR i = left TO right DO + temp[i] = A[i] + END FOR +END FUNCTION""" + + +@pytest.fixture +def linear_search() -> str: + """Linear search algorithm (O(n)).""" + return """FUNCTION linearSearch(A, n, target) + FOR i = 1 TO n DO + IF A[i] == target THEN + RETURN i + END IF + END FOR + RETURN -1 +END FUNCTION""" + + +@pytest.fixture +def matrix_multiplication() -> str: + """Matrix multiplication (O(n³)).""" + return """FUNCTION matrixMultiply(A, B, n) + FOR i = 1 TO n DO + FOR j = 1 TO n DO + C[i][j] = 0 + FOR k = 1 TO 
n DO + C[i][j] = C[i][j] + A[i][k] * B[k][j] + END FOR + END FOR + END FOR + RETURN C +END FUNCTION""" + + +# ============================================================================= +# Pseudocode Style Variations Fixtures +# ============================================================================= + +@pytest.fixture +def python_style_loop() -> str: + """Python-style pseudocode.""" + return """def bubble_sort(arr): + for i in range(len(arr)): + for j in range(len(arr) - i - 1): + if arr[j] > arr[j+1]: + swap(arr[j], arr[j+1])""" + + +@pytest.fixture +def pascal_style_loop() -> str: + """Pascal-style pseudocode.""" + return """PROCEDURE BubbleSort(A: ARRAY; n: INTEGER); +BEGIN + FOR i := 1 TO n-1 DO + FOR j := 1 TO n-i DO + IF A[j] > A[j+1] THEN + SWAP(A[j], A[j+1]) + END + END + END +END""" + + +@pytest.fixture +def c_style_loop() -> str: + """C-style pseudocode.""" + return """function bubbleSort(A[], n) { + for (i = 0; i < n-1; i++) { + for (j = 0; j < n-i-1; j++) { + if (A[j] > A[j+1]) { + swap(A[j], A[j+1]); + } + } + } +}""" + + +@pytest.fixture +def mixed_case_keywords() -> str: + """Pseudocode with mixed case keywords.""" + return """FOR i = 1 To n DO + While j < n Do + IF condition Then + j = j + 1 + ELSE + j = j - 1 + End If + End While +End For""" + + +@pytest.fixture +def unicode_operators() -> str: + """Pseudocode with unicode operators.""" + return """FOR i ← 1 TO n DO + IF A[i] ≤ max AND A[i] ≥ min THEN + IF A[i] ≠ target THEN + count ← count + 1 + END IF + END IF +END FOR""" + + +@pytest.fixture +def typos_in_keywords() -> str: + """Pseudocode with common typos.""" + return """FUCNTION test(n) + WHLIE i < n DO + i = i + 1 + END WHLIE + RETRUN result +END FUCNTION""" + + +# ============================================================================= +# Edge Case Fixtures +# ============================================================================= + +@pytest.fixture +def empty_function() -> str: + """Empty function body.""" + return 
"""FUNCTION emptyFunc() +END FUNCTION""" + + +@pytest.fixture +def deeply_nested() -> str: + """Deeply nested structure.""" + return """FOR i = 1 TO n DO + FOR j = 1 TO n DO + FOR k = 1 TO n DO + FOR l = 1 TO n DO + FOR m = 1 TO n DO + x = x + 1 + END FOR + END FOR + END FOR + END FOR +END FOR""" + + +@pytest.fixture +def multiple_functions() -> str: + """Multiple function definitions.""" + return """FUNCTION helper(x) + RETURN x * 2 +END FUNCTION + +FUNCTION main(n) + FOR i = 1 TO n DO + result = helper(i) + END FOR + RETURN result +END FUNCTION""" + + +@pytest.fixture +def foreach_loop() -> str: + """For-each loop pseudocode.""" + return """FOR EACH item IN collection DO + process(item) +END FOR""" + + +@pytest.fixture +def repeat_until_loop() -> str: + """Repeat-until loop pseudocode.""" + return """REPEAT + x = x + 1 +UNTIL x >= n""" + + +# ============================================================================= +# Expected Complexity Fixtures +# ============================================================================= + +@pytest.fixture +def complexity_test_cases() -> List[Dict[str, Any]]: + """Test cases with expected complexities.""" + return [ + { + "name": "constant", + "code": "x = 1\ny = 2\nz = x + y", + "expected_time": ComplexityClass.CONSTANT, + "expected_space": ComplexityClass.CONSTANT, + }, + { + "name": "single_loop", + "code": "FOR i = 1 TO n DO\n print(i)\nEND FOR", + "expected_time": ComplexityClass.LINEAR, + "expected_space": ComplexityClass.CONSTANT, + }, + { + "name": "nested_loops", + "code": "FOR i = 1 TO n DO\n FOR j = 1 TO n DO\n x = x + 1\n END FOR\nEND FOR", + "expected_time": ComplexityClass.QUADRATIC, + "expected_space": ComplexityClass.CONSTANT, + }, + { + "name": "triple_nested", + "code": "FOR i = 1 TO n DO\n FOR j = 1 TO n DO\n FOR k = 1 TO n DO\n x = 1\n END FOR\n END FOR\nEND FOR", + "expected_time": ComplexityClass.CUBIC, + "expected_space": ComplexityClass.CONSTANT, + }, + { + "name": "logarithmic", + "code": "WHILE 
n > 1 DO\n n = n / 2\nEND WHILE", + "expected_time": ComplexityClass.LOGARITHMIC, + "expected_space": ComplexityClass.CONSTANT, + }, + ] diff --git a/evaluation_function/tests/test_analyzer.py b/evaluation_function/tests/test_analyzer.py new file mode 100644 index 0000000..d0791e8 --- /dev/null +++ b/evaluation_function/tests/test_analyzer.py @@ -0,0 +1,1123 @@ +""" +Comprehensive tests for the Complexity Analyzer module. + +Tests cover: +- Loop detection and analysis +- Recursion detection and complexity +- Nested loop handling +- AST-based and pattern-based analysis +- Feedback generation +- Edge cases and various complexity classes +""" + +import pytest +from evaluation_function.analyzer.complexity_analyzer import ( + ComplexityAnalyzer, + AnalysisResult, + LoopInfo, + RecursionInfo, +) +from evaluation_function.analyzer.feedback_generator import ( + FeedbackGenerator, + DetailedFeedback, + FeedbackLevel, + FeedbackSection, +) +from evaluation_function.schemas.complexity import ComplexityClass +from evaluation_function.schemas.ast_nodes import ( + ProgramNode, + FunctionNode, + BlockNode, + LoopNode, + ConditionalNode, + VariableNode, + LiteralNode, + LoopType, +) + + +# ============================================================================ +# Fixtures +# ============================================================================ + +@pytest.fixture +def analyzer(): + """Create a fresh analyzer for each test.""" + return ComplexityAnalyzer() + + +@pytest.fixture +def feedback_generator(): + """Create a feedback generator.""" + return FeedbackGenerator() + + +# ============================================================================ +# LoopInfo Tests +# ============================================================================ + +class TestLoopInfo: + """Tests for LoopInfo dataclass.""" + + def test_for_loop_description_with_bounds(self): + """Test FOR loop description with start and end bounds.""" + loop = LoopInfo( + loop_type="for", + iterator="i", 
+ start_bound="1", + end_bound="n", + step="1", + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=0 + ) + desc = loop.get_description() + assert "FOR loop" in desc + assert "i" in desc + assert "1" in desc + assert "n" in desc + + def test_for_loop_description_without_bounds(self): + """Test FOR loop description without bounds.""" + loop = LoopInfo( + loop_type="for", + iterator="i", + start_bound=None, + end_bound=None, + step=None, + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=0 + ) + desc = loop.get_description() + assert "FOR loop" in desc + assert "iterator" in desc.lower() or "i" in desc + + def test_foreach_loop_description(self): + """Test FOR-EACH loop description.""" + loop = LoopInfo( + loop_type="for_each", + iterator="item", + start_bound=None, + end_bound="collection", + step=None, + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=0 + ) + desc = loop.get_description() + assert "FOR-EACH" in desc or "for_each" in desc.lower() + + def test_while_loop_description(self): + """Test WHILE loop description.""" + loop = LoopInfo( + loop_type="while", + iterator=None, + start_bound=None, + end_bound=None, + step=None, + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=0 + ) + desc = loop.get_description() + assert "WHILE" in desc.upper() + + def test_repeat_loop_description(self): + """Test REPEAT-UNTIL loop description.""" + loop = LoopInfo( + loop_type="repeat", + iterator=None, + start_bound=None, + end_bound=None, + step=None, + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=0 + ) + desc = loop.get_description() + assert "REPEAT" in desc.upper() + + +# ============================================================================ +# RecursionInfo Tests +# ============================================================================ + +class TestRecursionInfo: + """Tests for RecursionInfo dataclass.""" + + def test_linear_recursion_description(self): + 
"""Test linear recursion description.""" + rec = RecursionInfo( + function_name="factorial", + num_recursive_calls=1, + reduction_pattern="n-1", + branching_factor=1, + work_per_call=ComplexityClass.CONSTANT, + complexity=ComplexityClass.LINEAR, + recurrence="T(n) = T(n-1) + O(1)" + ) + desc = rec.get_description() + assert "Linear recursion" in desc + assert "factorial" in desc + + def test_divide_conquer_description(self): + """Test divide-and-conquer recursion description.""" + rec = RecursionInfo( + function_name="mergeSort", + num_recursive_calls=2, + reduction_pattern="n/2", + branching_factor=2, + work_per_call=ComplexityClass.LINEAR, + complexity=ComplexityClass.LINEARITHMIC, + recurrence="T(n) = 2T(n/2) + O(n)" + ) + desc = rec.get_description() + assert "Divide-and-conquer" in desc + assert "mergeSort" in desc + + def test_binary_recursion_description(self): + """Test binary recursion (non-divide-conquer) description.""" + rec = RecursionInfo( + function_name="fib", + num_recursive_calls=2, + reduction_pattern="n-1", + branching_factor=2, + work_per_call=ComplexityClass.CONSTANT, + complexity=ComplexityClass.EXPONENTIAL, + recurrence="T(n) = 2T(n-1) + O(1)" + ) + desc = rec.get_description() + assert "Binary recursion" in desc + assert "fib" in desc + + def test_multiple_recursion_description(self): + """Test multiple (>2) recursion description.""" + rec = RecursionInfo( + function_name="multiRec", + num_recursive_calls=3, + reduction_pattern="n-1", + branching_factor=3, + work_per_call=ComplexityClass.CONSTANT, + complexity=ComplexityClass.EXPONENTIAL, + recurrence="T(n) = 3T(n-1) + O(1)" + ) + desc = rec.get_description() + assert "Multiple recursion" in desc or "3" in desc + + +# ============================================================================ +# AnalysisResult Tests +# ============================================================================ + +class TestAnalysisResult: + """Tests for AnalysisResult dataclass.""" + + def 
test_get_complexity_string(self): + """Test complexity string retrieval.""" + result = AnalysisResult( + time_complexity=ComplexityClass.QUADRATIC, + space_complexity=ComplexityClass.LINEAR, + loops=[], + recursion=None, + max_nesting_depth=2, + confidence=0.9, + factors=[] + ) + assert result.get_complexity_string() == "O(n²)" + + def test_result_with_loops(self): + """Test result with loop information.""" + loop = LoopInfo( + loop_type="for", + iterator="i", + start_bound="1", + end_bound="n", + step="1", + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=0 + ) + result = AnalysisResult( + time_complexity=ComplexityClass.LINEAR, + space_complexity=ComplexityClass.CONSTANT, + loops=[loop], + recursion=None, + max_nesting_depth=1, + confidence=0.9, + factors=[] + ) + assert len(result.loops) == 1 + assert result.recursion is None + + +# ============================================================================ +# ComplexityAnalyzer - Pattern-Based Tests +# ============================================================================ + +class TestAnalyzerPatternBased: + """Tests for pattern-based analysis (no AST).""" + + def test_constant_complexity_no_loops(self, analyzer): + """Test constant complexity for code without loops.""" + code = """ + FUNCTION simple(x) + result = x + 1 + RETURN result + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.time_complexity == ComplexityClass.CONSTANT + assert len(result.loops) == 0 + + def test_single_for_loop_linear(self, analyzer): + """Test single FOR loop gives linear complexity.""" + code = """ + FUNCTION sum(A, n) + total = 0 + FOR i = 1 TO n DO + total = total + A[i] + END FOR + RETURN total + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.time_complexity == ComplexityClass.LINEAR + assert len(result.loops) == 1 + assert result.loops[0].loop_type == "for" + + def test_nested_for_loops_quadratic(self, analyzer): + """Test nested FOR loops give quadratic 
complexity.""" + code = """ + FUNCTION bubbleSort(A, n) + FOR i = 1 TO n DO + FOR j = 1 TO n DO + IF A[j] > A[j+1] THEN + swap(A[j], A[j+1]) + END IF + END FOR + END FOR + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.time_complexity == ComplexityClass.QUADRATIC + assert result.max_nesting_depth == 2 + + def test_triple_nested_loops_cubic(self, analyzer): + """Test triple nested loops give cubic complexity.""" + code = """ + FUNCTION matrixMultiply(A, B, n) + FOR i = 1 TO n DO + FOR j = 1 TO n DO + FOR k = 1 TO n DO + C[i][j] = C[i][j] + A[i][k] * B[k][j] + END FOR + END FOR + END FOR + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.time_complexity == ComplexityClass.CUBIC + assert result.max_nesting_depth == 3 + + def test_while_loop_linear(self, analyzer): + """Test WHILE loop detection.""" + code = """ + FUNCTION findElement(A, target) + i = 0 + WHILE i < n DO + IF A[i] == target THEN + RETURN i + END IF + i = i + 1 + END WHILE + RETURN -1 + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.time_complexity == ComplexityClass.LINEAR + assert len(result.loops) == 1 + assert result.loops[0].loop_type == "while" + + def test_while_loop_logarithmic(self, analyzer): + """Test WHILE loop with halving gives logarithmic complexity.""" + code = """ + FUNCTION binarySearch(A, target) + low = 0 + high = n + WHILE low <= high DO + mid = (low + high) / 2 + IF A[mid] == target THEN + RETURN mid + ELSE IF A[mid] < target THEN + low = mid + 1 + ELSE + high = mid - 1 + END IF + END WHILE + RETURN -1 + END FUNCTION + """ + result = analyzer.analyze(code) + # Binary search should be detected as logarithmic + assert result.time_complexity in [ComplexityClass.LOGARITHMIC, ComplexityClass.LINEAR] + + def test_foreach_loop_linear(self, analyzer): + """Test FOR-EACH loop detection.""" + code = """ + FUNCTION printAll(collection) + FOR EACH item IN collection DO + print(item) + END FOR + END FUNCTION + """ + result = 
analyzer.analyze(code) + assert result.time_complexity == ComplexityClass.LINEAR + assert len(result.loops) == 1 + assert result.loops[0].loop_type == "for_each" + + def test_repeat_until_loop(self, analyzer): + """Test REPEAT-UNTIL loop detection.""" + code = """ + FUNCTION readInput() + REPEAT + input = read() + UNTIL input == "quit" + END FUNCTION + """ + result = analyzer.analyze(code) + assert len(result.loops) == 1 + assert result.loops[0].loop_type == "repeat" + + +# ============================================================================ +# ComplexityAnalyzer - Recursion Tests +# ============================================================================ + +class TestAnalyzerRecursion: + """Tests for recursion detection and analysis.""" + + def test_simple_linear_recursion(self, analyzer): + """Test simple linear recursion (factorial-like).""" + code = """ + FUNCTION factorial(n) + IF n <= 1 THEN + RETURN 1 + END IF + RETURN n * factorial(n - 1) + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.recursion is not None + assert result.recursion.function_name == "factorial" + assert result.recursion.branching_factor == 1 + assert result.time_complexity == ComplexityClass.LINEAR + + def test_binary_recursion_exponential(self, analyzer): + """Test binary recursion (Fibonacci-like) is exponential.""" + code = """ + FUNCTION fib(n) + IF n <= 1 THEN + RETURN n + END IF + RETURN fib(n - 1) + fib(n - 2) + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.recursion is not None + assert result.recursion.function_name == "fib" + assert result.recursion.branching_factor >= 2 + assert result.time_complexity == ComplexityClass.EXPONENTIAL + + def test_divide_and_conquer_merge_sort(self, analyzer): + """Test divide-and-conquer recursion (merge sort pattern).""" + code = """ + FUNCTION mergeSort(A, low, high) + IF low < high THEN + mid = (low + high) / 2 + mergeSort(A, low, mid) + mergeSort(A, mid + 1, high) + merge(A, low, mid, 
high) + END IF + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.recursion is not None + assert result.recursion.function_name.lower() == "mergesort" + assert "n/2" in result.recursion.reduction_pattern + assert result.time_complexity == ComplexityClass.LINEARITHMIC + + def test_binary_search_recursion(self, analyzer): + """Test binary search recursive pattern.""" + code = """ + FUNCTION binarySearch(A, target, low, high) + IF low > high THEN + RETURN -1 + END IF + mid = (low + high) / 2 + IF A[mid] == target THEN + RETURN mid + ELSE IF A[mid] < target THEN + RETURN binarySearch(A, target, mid + 1, high) + ELSE + RETURN binarySearch(A, target, low, mid - 1) + END IF + END FUNCTION + """ + result = analyzer.analyze(code) + assert result.recursion is not None + assert result.recursion.branching_factor == 1 + assert result.time_complexity == ComplexityClass.LOGARITHMIC + + def test_space_complexity_recursive(self, analyzer): + """Test space complexity for recursive functions.""" + code = """ + FUNCTION factorial(n) + IF n <= 1 THEN + RETURN 1 + END IF + RETURN n * factorial(n - 1) + END FUNCTION + """ + result = analyzer.analyze(code) + # Linear recursion has O(n) stack depth + assert result.space_complexity == ComplexityClass.LINEAR + + def test_space_complexity_divide_conquer(self, analyzer): + """Test space complexity for divide-and-conquer.""" + code = """ + FUNCTION binarySearch(A, target, low, high) + IF low > high THEN + RETURN -1 + END IF + mid = (low + high) / 2 + RETURN binarySearch(A, target, mid + 1, high) + END FUNCTION + """ + result = analyzer.analyze(code) + # Divide-and-conquer has O(log n) stack depth + assert result.space_complexity == ComplexityClass.LOGARITHMIC + + +# ============================================================================ +# ComplexityAnalyzer - AST-Based Tests +# ============================================================================ + +class TestAnalyzerAST: + """Tests for AST-based analysis.""" + 
+ def test_analyze_ast_single_loop(self, analyzer): + """Test AST analysis with single loop.""" + loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + start=LiteralNode(value=1), + end=VariableNode(name="n"), + body=BlockNode(statements=[]), + estimated_iterations="n" + ) + func = FunctionNode( + name="test", + parameters=[], + body=BlockNode(statements=[loop]) + ) + ast = ProgramNode(functions=[func]) + + result = analyzer.analyze("", ast) + assert result.time_complexity == ComplexityClass.LINEAR + assert len(result.loops) == 1 + + def test_analyze_ast_nested_loops(self, analyzer): + """Test AST analysis with nested loops.""" + inner_loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="j"), + start=LiteralNode(value=1), + end=VariableNode(name="n"), + body=BlockNode(statements=[]), + estimated_iterations="n" + ) + outer_loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + start=LiteralNode(value=1), + end=VariableNode(name="n"), + body=BlockNode(statements=[inner_loop]), + estimated_iterations="n" + ) + func = FunctionNode( + name="test", + parameters=[], + body=BlockNode(statements=[outer_loop]) + ) + ast = ProgramNode(functions=[func]) + + result = analyzer.analyze("", ast) + assert result.time_complexity == ComplexityClass.QUADRATIC + + def test_analyze_ast_global_statements(self, analyzer): + """Test AST analysis with global statements (no functions).""" + loop = LoopNode( + loop_type=LoopType.WHILE, + condition=None, + body=BlockNode(statements=[]), + estimated_iterations="n" + ) + ast = ProgramNode( + functions=[], + global_statements=BlockNode(statements=[loop]) + ) + + result = analyzer.analyze("", ast) + assert result.time_complexity == ComplexityClass.LINEAR + + def test_analyze_ast_loop_with_conditional(self, analyzer): + """Test AST analysis with loop containing conditional.""" + inner_loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="k"), + 
start=LiteralNode(value=1), + end=VariableNode(name="n"), + body=BlockNode(statements=[]), + estimated_iterations="n" + ) + conditional = ConditionalNode( + then_branch=BlockNode(statements=[inner_loop]), + else_branch=None + ) + outer_loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + start=LiteralNode(value=1), + end=VariableNode(name="n"), + body=BlockNode(statements=[conditional]), + estimated_iterations="n" + ) + func = FunctionNode( + name="test", + parameters=[], + body=BlockNode(statements=[outer_loop]) + ) + ast = ProgramNode(functions=[func]) + + result = analyzer.analyze("", ast) + # Loop inside conditional inside loop = quadratic + assert result.time_complexity == ComplexityClass.QUADRATIC + + +# ============================================================================ +# ComplexityAnalyzer - Iteration Estimation Tests +# ============================================================================ + +class TestIterationEstimation: + """Tests for iteration estimation logic.""" + + def test_constant_bounds(self, analyzer): + """Test constant iteration bounds.""" + code = """ + FOR i = 1 TO 10 DO + print(i) + END FOR + """ + result = analyzer.analyze(code) + assert result.loops[0].iterations == "10" + assert result.loops[0].complexity == ComplexityClass.CONSTANT + + def test_variable_n_bound(self, analyzer): + """Test variable 'n' as upper bound.""" + code = """ + FOR i = 1 TO n DO + print(i) + END FOR + """ + result = analyzer.analyze(code) + assert result.loops[0].iterations == "n" + assert result.loops[0].complexity == ComplexityClass.LINEAR + + def test_length_bound(self, analyzer): + """Test 'length' as upper bound.""" + code = """ + FOR i = 0 TO length DO + print(i) + END FOR + """ + result = analyzer.analyze(code) + assert result.loops[0].complexity == ComplexityClass.LINEAR + + +# ============================================================================ +# FeedbackGenerator Tests +# 
============================================================================ + +class TestFeedbackGenerator: + """Tests for FeedbackGenerator.""" + + def test_generate_constant_feedback(self, feedback_generator, analyzer): + """Test feedback for constant complexity.""" + code = "x = 5" + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result) + + assert feedback.complexity_result == "O(1)" + assert "Constant" in feedback.summary or "O(1)" in feedback.summary + + def test_generate_linear_feedback(self, feedback_generator, analyzer): + """Test feedback for linear complexity.""" + code = """ + FOR i = 1 TO n DO + print(i) + END FOR + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result) + + assert feedback.complexity_result == "O(n)" + assert feedback.loop_count == 1 + assert feedback.max_nesting == 1 + + def test_generate_quadratic_feedback(self, feedback_generator, analyzer): + """Test feedback for quadratic complexity.""" + code = """ + FOR i = 1 TO n DO + FOR j = 1 TO n DO + print(i, j) + END FOR + END FOR + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result) + + assert feedback.complexity_result == "O(n²)" + assert feedback.max_nesting == 2 + assert "nested" in feedback.summary.lower() or "loop" in feedback.summary.lower() + + def test_generate_recursion_feedback(self, feedback_generator, analyzer): + """Test feedback for recursive algorithms.""" + code = """ + FUNCTION fib(n) + IF n <= 1 THEN + RETURN n + END IF + RETURN fib(n - 1) + fib(n - 2) + END FUNCTION + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result) + + assert feedback.has_recursion + assert "recursion" in feedback.summary.lower() + + def test_feedback_sections(self, feedback_generator, analyzer): + """Test that feedback sections are generated.""" + code = """ + FOR i = 1 TO n DO + FOR j = 1 TO n DO + print(i, j) + END FOR + END FOR + """ + result = analyzer.analyze(code) + 
feedback = feedback_generator.generate(result, FeedbackLevel.DETAILED) + + assert len(feedback.sections) > 0 + section_titles = [s.title for s in feedback.sections] + assert any("Loop" in t for t in section_titles) + + def test_feedback_suggestions(self, feedback_generator, analyzer): + """Test that suggestions are generated for complex algorithms.""" + code = """ + FUNCTION fib(n) + IF n <= 1 THEN + RETURN n + END IF + RETURN fib(n - 1) + fib(n - 2) + END FUNCTION + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result) + + # Exponential should have optimization suggestions + assert len(feedback.suggestions) > 0 + assert any("dynamic" in s.lower() or "memoization" in s.lower() + for s in feedback.suggestions) + + +# ============================================================================ +# FeedbackGenerator - Output Format Tests +# ============================================================================ + +class TestFeedbackFormats: + """Tests for feedback output formats.""" + + def test_to_string_brief(self, feedback_generator, analyzer): + """Test brief string format.""" + code = "FOR i = 1 TO n DO\n print(i)\nEND FOR" + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result, FeedbackLevel.BRIEF) + + output = feedback.to_string(FeedbackLevel.BRIEF) + assert "O(n)" in output + assert "COMPLEXITY ANALYSIS RESULT" in output + + def test_to_string_standard(self, feedback_generator, analyzer): + """Test standard string format.""" + code = "FOR i = 1 TO n DO\n print(i)\nEND FOR" + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result, FeedbackLevel.STANDARD) + + output = feedback.to_string(FeedbackLevel.STANDARD) + assert "O(n)" in output + assert "What does this mean?" 
in output + + def test_to_string_detailed(self, feedback_generator, analyzer): + """Test detailed string format.""" + code = "FOR i = 1 TO n DO\n print(i)\nEND FOR" + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result, FeedbackLevel.DETAILED) + + output = feedback.to_string(FeedbackLevel.DETAILED) + assert "O(n)" in output + assert "Real-World" in output or "Analysis Confidence" in output + + def test_to_dict(self, feedback_generator, analyzer): + """Test dictionary conversion.""" + code = "FOR i = 1 TO n DO\n print(i)\nEND FOR" + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result) + + data = feedback.to_dict() + assert "summary" in data + assert "complexity" in data + assert "sections" in data + assert "stats" in data + assert data["complexity"] == "O(n)" + + def test_format_for_student(self, feedback_generator, analyzer): + """Test student-friendly format.""" + code = """ + FOR i = 1 TO n DO + FOR j = 1 TO n DO + print(i, j) + END FOR + END FOR + """ + result = analyzer.analyze(code) + output = feedback_generator.format_for_student(result) + + assert "O(n²)" in output + assert len(output) > 100 # Should be detailed + + def test_format_brief(self, feedback_generator, analyzer): + """Test brief one-line format.""" + code = "FOR i = 1 TO n DO\n print(i)\nEND FOR" + result = analyzer.analyze(code) + output = feedback_generator.format_brief(result) + + assert "O(n)" in output + assert "Time Complexity" in output + + +# ============================================================================ +# FeedbackGenerator - Complexity Explanations +# ============================================================================ + +class TestComplexityExplanations: + """Tests for complexity class explanations.""" + + def test_constant_explanation(self, feedback_generator, analyzer): + """Test constant complexity explanation.""" + code = "x = 5" + result = analyzer.analyze(code) + feedback = 
feedback_generator.generate(result) + + assert "Constant" in feedback.complexity_explanation + assert "same amount of time" in feedback.complexity_explanation.lower() or "O(1)" in feedback.complexity_explanation + + def test_linear_explanation(self, feedback_generator, analyzer): + """Test linear complexity explanation.""" + code = "FOR i = 1 TO n DO\n print(i)\nEND FOR" + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result) + + assert "Linear" in feedback.complexity_explanation + + def test_quadratic_explanation(self, feedback_generator, analyzer): + """Test quadratic complexity explanation.""" + code = """ + FOR i = 1 TO n DO + FOR j = 1 TO n DO + print(i, j) + END FOR + END FOR + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result) + + assert "Quadratic" in feedback.complexity_explanation + assert "nested" in feedback.complexity_explanation.lower() or "n²" in feedback.complexity_explanation + + +# ============================================================================ +# Edge Cases and Special Scenarios +# ============================================================================ + +class TestEdgeCases: + """Tests for edge cases and special scenarios.""" + + def test_empty_code(self, analyzer): + """Test handling of empty code.""" + result = analyzer.analyze("") + assert result.time_complexity == ComplexityClass.CONSTANT + assert len(result.loops) == 0 + + def test_whitespace_only(self, analyzer): + """Test handling of whitespace-only code.""" + result = analyzer.analyze(" \n\n \t ") + assert result.time_complexity == ComplexityClass.CONSTANT + + def test_comments_only(self, analyzer): + """Test handling of comments-only code.""" + code = """ + // This is a comment + # Another comment + """ + result = analyzer.analyze(code) + assert result.time_complexity == ComplexityClass.CONSTANT + + def test_multiple_independent_loops(self, analyzer): + """Test multiple independent (not nested) loops.""" + 
code = """ + FOR i = 1 TO n DO + print(i) + END FOR + + FOR j = 1 TO n DO + print(j) + END FOR + """ + result = analyzer.analyze(code) + # Independent loops don't multiply - take max + assert result.time_complexity == ComplexityClass.LINEAR + assert len(result.loops) == 2 + + def test_deeply_nested_loops(self, analyzer): + """Test very deeply nested loops.""" + code = """ + FOR i = 1 TO n DO + FOR j = 1 TO n DO + FOR k = 1 TO n DO + FOR l = 1 TO n DO + print(i, j, k, l) + END FOR + END FOR + END FOR + END FOR + """ + result = analyzer.analyze(code) + # 4 nested loops = O(n^4) = polynomial + assert result.time_complexity in [ComplexityClass.POLYNOMIAL, ComplexityClass.CUBIC] + assert result.max_nesting_depth >= 3 + + def test_mixed_loop_types(self, analyzer): + """Test code with different loop types.""" + code = """ + FOR i = 1 TO n DO + WHILE j < n DO + j = j + 1 + END WHILE + END FOR + """ + result = analyzer.analyze(code) + assert result.time_complexity == ComplexityClass.QUADRATIC + + def test_loop_with_constant_bound(self, analyzer): + """Test loop with constant upper bound (not O(n)).""" + code = """ + FOR i = 1 TO 100 DO + print(i) + END FOR + """ + result = analyzer.analyze(code) + # Constant bound = O(1) + assert result.loops[0].complexity == ComplexityClass.CONSTANT + + def test_confidence_levels(self, analyzer): + """Test that confidence is set appropriately.""" + # Simple code - high confidence + simple_result = analyzer.analyze("x = 5") + assert simple_result.confidence >= 0.7 + + # Complex code with loops - high confidence + loop_code = "FOR i = 1 TO n DO\n print(i)\nEND FOR" + loop_result = analyzer.analyze(loop_code) + assert loop_result.confidence >= 0.8 + + +# ============================================================================ +# Integration Tests +# ============================================================================ + +class TestAnalyzerIntegration: + """Integration tests combining analyzer and feedback.""" + + def 
test_full_pipeline_linear(self, analyzer, feedback_generator): + """Test full analysis pipeline for linear algorithm.""" + code = """ + FUNCTION linearSearch(A, n, target) + FOR i = 1 TO n DO + IF A[i] == target THEN + RETURN i + END IF + END FOR + RETURN -1 + END FUNCTION + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result, FeedbackLevel.DETAILED) + + assert result.time_complexity == ComplexityClass.LINEAR + assert "O(n)" in feedback.to_string() + + def test_full_pipeline_quadratic(self, analyzer, feedback_generator): + """Test full analysis pipeline for quadratic algorithm.""" + code = """ + FUNCTION bubbleSort(A, n) + FOR i = 1 TO n DO + FOR j = 1 TO n-1 DO + IF A[j] > A[j+1] THEN + swap(A[j], A[j+1]) + END IF + END FOR + END FOR + END FUNCTION + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result, FeedbackLevel.DETAILED) + + assert result.time_complexity == ComplexityClass.QUADRATIC + assert "O(n²)" in feedback.to_string() + assert "nested" in feedback.to_string().lower() + + def test_full_pipeline_recursive(self, analyzer, feedback_generator): + """Test full analysis pipeline for recursive algorithm.""" + code = """ + FUNCTION factorial(n) + IF n <= 1 THEN + RETURN 1 + END IF + RETURN n * factorial(n - 1) + END FUNCTION + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result, FeedbackLevel.DETAILED) + + assert result.recursion is not None + output = feedback.to_string() + assert "recursion" in output.lower() or "recursive" in output.lower() + + def test_full_pipeline_exponential(self, analyzer, feedback_generator): + """Test full analysis pipeline for exponential algorithm.""" + code = """ + FUNCTION fib(n) + IF n <= 1 THEN + RETURN n + END IF + RETURN fib(n-1) + fib(n-2) + END FUNCTION + """ + result = analyzer.analyze(code) + feedback = feedback_generator.generate(result, FeedbackLevel.DETAILED) + + assert result.time_complexity == 
ComplexityClass.EXPONENTIAL + output = feedback.to_string() + assert "O(2^n)" in output or "Exponential" in output + # Should have optimization suggestions + assert len(feedback.suggestions) > 0 + + +# ============================================================================ +# FeedbackSection Tests +# ============================================================================ + +class TestFeedbackSection: + """Tests for FeedbackSection class.""" + + def test_section_creation(self): + """Test creating a feedback section.""" + section = FeedbackSection( + title="Test Section", + content="This is test content", + importance="info" + ) + assert section.title == "Test Section" + assert section.content == "This is test content" + assert section.importance == "info" + + def test_section_importance_levels(self): + """Test different importance levels.""" + for importance in ["info", "warning", "success", "error"]: + section = FeedbackSection( + title="Test", + content="Content", + importance=importance + ) + assert section.importance == importance + + +# ============================================================================ +# DetailedFeedback Tests +# ============================================================================ + +class TestDetailedFeedback: + """Tests for DetailedFeedback class.""" + + def test_feedback_creation(self): + """Test creating detailed feedback.""" + feedback = DetailedFeedback( + summary="Test summary", + complexity_result="O(n)", + loop_count=1, + max_nesting=1, + has_recursion=False + ) + assert feedback.summary == "Test summary" + assert feedback.complexity_result == "O(n)" + assert feedback.loop_count == 1 + + def test_feedback_with_sections(self): + """Test feedback with sections.""" + section = FeedbackSection( + title="Loop Analysis", + content="One loop detected", + importance="info" + ) + feedback = DetailedFeedback( + summary="Test", + complexity_result="O(n)", + sections=[section] + ) + assert len(feedback.sections) == 1 + 
assert feedback.sections[0].title == "Loop Analysis" + + def test_feedback_to_dict_structure(self): + """Test dictionary structure of feedback.""" + feedback = DetailedFeedback( + summary="Test summary", + complexity_result="O(n²)", + loop_count=2, + max_nesting=2, + has_recursion=False, + complexity_explanation="Quadratic complexity", + real_world_example="Bubble sort", + suggestions=["Consider optimization"], + confidence_note="High confidence" + ) + data = feedback.to_dict() + + assert data["summary"] == "Test summary" + assert data["complexity"] == "O(n²)" + assert data["stats"]["loop_count"] == 2 + assert data["stats"]["max_nesting"] == 2 + assert data["stats"]["has_recursion"] == False + assert data["explanation"] == "Quadratic complexity" + assert "Consider optimization" in data["suggestions"] diff --git a/evaluation_function/tests/test_ast_builder.py b/evaluation_function/tests/test_ast_builder.py new file mode 100644 index 0000000..3dd5e80 --- /dev/null +++ b/evaluation_function/tests/test_ast_builder.py @@ -0,0 +1,723 @@ +""" +Comprehensive tests for the AST Builder module. 
+ +Tests cover: +- AST node creation +- Node type verification +- Tree structure validation +- Expression handling +- Statement handling +""" + +import pytest +from ..schemas.ast_nodes import ( + ProgramNode, FunctionNode, BlockNode, LoopNode, ConditionalNode, + AssignmentNode, ReturnNode, FunctionCallNode, RecursiveCallNode, + VariableNode, LiteralNode, BinaryOpNode, UnaryOpNode, ArrayAccessNode, + NodeType, LoopType, OperatorType, SourceLocation +) + + +class TestASTNodeTypes: + """Tests for AST node type enumeration.""" + + def test_node_types_exist(self): + """Test that all expected node types exist.""" + expected_types = [ + NodeType.PROGRAM, NodeType.FUNCTION, NodeType.BLOCK, + NodeType.LOOP, NodeType.CONDITIONAL, NodeType.ASSIGNMENT, + NodeType.RETURN, NodeType.FUNCTION_CALL, NodeType.RECURSIVE_CALL, + NodeType.VARIABLE, NodeType.LITERAL, NodeType.BINARY_OP, + NodeType.UNARY_OP, NodeType.ARRAY_ACCESS, NodeType.EXPRESSION, + ] + + for node_type in expected_types: + assert node_type is not None + + def test_loop_types_exist(self): + """Test that all expected loop types exist.""" + expected_types = [ + LoopType.FOR, LoopType.FOR_EACH, LoopType.WHILE, + LoopType.DO_WHILE, LoopType.REPEAT_UNTIL, LoopType.UNKNOWN, + ] + + for loop_type in expected_types: + assert loop_type is not None + + def test_operator_types_exist(self): + """Test that all expected operator types exist.""" + arithmetic_ops = [ + OperatorType.ADD, OperatorType.SUBTRACT, OperatorType.MULTIPLY, + OperatorType.DIVIDE, OperatorType.MODULO, OperatorType.POWER, + ] + comparison_ops = [ + OperatorType.EQUAL, OperatorType.NOT_EQUAL, OperatorType.LESS_THAN, + OperatorType.LESS_EQUAL, OperatorType.GREATER_THAN, OperatorType.GREATER_EQUAL, + ] + logical_ops = [ + OperatorType.AND, OperatorType.OR, OperatorType.NOT, + ] + + for op in arithmetic_ops + comparison_ops + logical_ops: + assert op is not None + + +class TestProgramNode: + """Tests for ProgramNode.""" + + def test_create_empty_program(self): + 
"""Test creating empty program node.""" + program = ProgramNode() + + assert program.node_type == NodeType.PROGRAM + assert len(program.functions) == 0 + assert program.global_statements is None + + def test_create_program_with_functions(self): + """Test creating program with functions.""" + func = FunctionNode(name="test", parameters=[], body=None) + program = ProgramNode(functions=[func]) + + assert len(program.functions) == 1 + assert program.functions[0].name == "test" + + def test_create_program_with_global_statements(self): + """Test creating program with global statements.""" + stmt = AssignmentNode( + target=VariableNode(name="x"), + value=LiteralNode(value=1, literal_type="int") + ) + block = BlockNode(statements=[stmt]) + program = ProgramNode(global_statements=block) + + assert program.global_statements is not None + assert len(program.global_statements.statements) == 1 + + def test_program_to_dict(self): + """Test program serialization to dict.""" + program = ProgramNode() + result = program.model_dump() + + assert "node_type" in result + assert "functions" in result + assert result["node_type"] == NodeType.PROGRAM + + +class TestFunctionNode: + """Tests for FunctionNode.""" + + def test_create_simple_function(self): + """Test creating simple function node.""" + func = FunctionNode(name="test", parameters=[], body=None) + + assert func.node_type == NodeType.FUNCTION + assert func.name == "test" + assert len(func.parameters) == 0 + assert func.is_recursive == False + + def test_create_function_with_parameters(self): + """Test creating function with parameters.""" + params = [ + VariableNode(name="a"), + VariableNode(name="b"), + ] + func = FunctionNode(name="add", parameters=params, body=None) + + assert len(func.parameters) == 2 + assert func.parameters[0].name == "a" + assert func.parameters[1].name == "b" + + def test_create_function_with_body(self): + """Test creating function with body.""" + return_stmt = ReturnNode(value=LiteralNode(value=1, 
literal_type="int")) + body = BlockNode(statements=[return_stmt]) + func = FunctionNode(name="getOne", parameters=[], body=body) + + assert func.body is not None + assert len(func.body.statements) == 1 + + def test_create_recursive_function(self): + """Test creating recursive function.""" + func = FunctionNode(name="factorial", parameters=[], body=None, is_recursive=True) + + assert func.is_recursive == True + + def test_function_to_dict(self): + """Test function serialization to dict.""" + func = FunctionNode(name="test", parameters=[], body=None) + result = func.model_dump() + + assert result["name"] == "test" + assert result["node_type"] == NodeType.FUNCTION + + +class TestBlockNode: + """Tests for BlockNode.""" + + def test_create_empty_block(self): + """Test creating empty block node.""" + block = BlockNode() + + assert block.node_type == NodeType.BLOCK + assert len(block.statements) == 0 + + def test_create_block_with_statements(self): + """Test creating block with statements.""" + stmt1 = AssignmentNode( + target=VariableNode(name="x"), + value=LiteralNode(value=1, literal_type="int") + ) + stmt2 = AssignmentNode( + target=VariableNode(name="y"), + value=LiteralNode(value=2, literal_type="int") + ) + block = BlockNode(statements=[stmt1, stmt2]) + + assert len(block.statements) == 2 + + def test_nested_blocks(self): + """Test nested blocks.""" + inner = BlockNode(statements=[]) + outer = BlockNode(statements=[inner]) + + assert len(outer.statements) == 1 + + +class TestLoopNode: + """Tests for LoopNode.""" + + def test_create_for_loop(self): + """Test creating FOR loop node.""" + loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + start=LiteralNode(value=1, literal_type="int"), + end=VariableNode(name="n"), + body=BlockNode(statements=[]) + ) + + assert loop.node_type == NodeType.LOOP + assert loop.loop_type == LoopType.FOR + assert loop.iterator.name == "i" + + def test_create_for_loop_with_step(self): + """Test creating FOR loop 
with step.""" + loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + start=LiteralNode(value=1, literal_type="int"), + end=VariableNode(name="n"), + step=LiteralNode(value=2, literal_type="int"), + body=BlockNode(statements=[]) + ) + + assert loop.step is not None + assert loop.step.value == 2 + + def test_create_while_loop(self): + """Test creating WHILE loop node.""" + condition = BinaryOpNode( + operator=OperatorType.LESS_THAN, + left=VariableNode(name="i"), + right=VariableNode(name="n") + ) + loop = LoopNode( + loop_type=LoopType.WHILE, + condition=condition, + body=BlockNode(statements=[]) + ) + + assert loop.loop_type == LoopType.WHILE + assert loop.condition is not None + + def test_create_foreach_loop(self): + """Test creating FOR-EACH loop node.""" + loop = LoopNode( + loop_type=LoopType.FOR_EACH, + iterator=VariableNode(name="item"), + collection=VariableNode(name="list"), + body=BlockNode(statements=[]) + ) + + assert loop.loop_type == LoopType.FOR_EACH + assert loop.collection is not None + + def test_create_repeat_until_loop(self): + """Test creating REPEAT-UNTIL loop node.""" + condition = BinaryOpNode( + operator=OperatorType.GREATER_EQUAL, + left=VariableNode(name="x"), + right=VariableNode(name="n") + ) + loop = LoopNode( + loop_type=LoopType.REPEAT_UNTIL, + condition=condition, + body=BlockNode(statements=[]) + ) + + assert loop.loop_type == LoopType.REPEAT_UNTIL + + def test_loop_nesting_level(self): + """Test loop nesting level.""" + inner_loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="j"), + nesting_level=1, + body=BlockNode(statements=[]) + ) + outer_loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + nesting_level=0, + body=BlockNode(statements=[inner_loop]) + ) + + assert outer_loop.nesting_level == 0 + assert inner_loop.nesting_level == 1 + + def test_loop_estimated_iterations(self): + """Test loop estimated iterations.""" + loop = LoopNode( + 
loop_type=LoopType.FOR, + estimated_iterations="n", + body=BlockNode(statements=[]) + ) + + assert loop.estimated_iterations == "n" + + def test_loop_to_dict(self): + """Test loop serialization to dict.""" + loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + body=BlockNode(statements=[]) + ) + result = loop.model_dump() + + assert result["loop_type"] == LoopType.FOR + assert result["node_type"] == NodeType.LOOP + + +class TestConditionalNode: + """Tests for ConditionalNode.""" + + def test_create_simple_if(self): + """Test creating simple IF node.""" + condition = BinaryOpNode( + operator=OperatorType.GREATER_THAN, + left=VariableNode(name="x"), + right=LiteralNode(value=0, literal_type="int") + ) + cond = ConditionalNode( + condition=condition, + then_branch=BlockNode(statements=[]) + ) + + assert cond.node_type == NodeType.CONDITIONAL + assert cond.condition is not None + assert cond.then_branch is not None + assert cond.else_branch is None + + def test_create_if_else(self): + """Test creating IF-ELSE node.""" + condition = BinaryOpNode( + operator=OperatorType.GREATER_THAN, + left=VariableNode(name="x"), + right=LiteralNode(value=0, literal_type="int") + ) + cond = ConditionalNode( + condition=condition, + then_branch=BlockNode(statements=[]), + else_branch=BlockNode(statements=[]) + ) + + assert cond.else_branch is not None + + def test_create_if_elif_else(self): + """Test creating IF-ELIF-ELSE node.""" + condition1 = BinaryOpNode( + operator=OperatorType.GREATER_THAN, + left=VariableNode(name="x"), + right=LiteralNode(value=0, literal_type="int") + ) + condition2 = BinaryOpNode( + operator=OperatorType.LESS_THAN, + left=VariableNode(name="x"), + right=LiteralNode(value=0, literal_type="int") + ) + elif_branch = ConditionalNode( + condition=condition2, + then_branch=BlockNode(statements=[]) + ) + cond = ConditionalNode( + condition=condition1, + then_branch=BlockNode(statements=[]), + elif_branches=[elif_branch], + 
else_branch=BlockNode(statements=[]) + ) + + assert len(cond.elif_branches) == 1 + + +class TestAssignmentNode: + """Tests for AssignmentNode.""" + + def test_create_simple_assignment(self): + """Test creating simple assignment node.""" + assign = AssignmentNode( + target=VariableNode(name="x"), + value=LiteralNode(value=1, literal_type="int") + ) + + assert assign.node_type == NodeType.ASSIGNMENT + assert assign.target.name == "x" + assert assign.value.value == 1 + + def test_create_array_assignment(self): + """Test creating array element assignment.""" + assign = AssignmentNode( + target=ArrayAccessNode( + array=VariableNode(name="A"), + index=VariableNode(name="i") + ), + value=LiteralNode(value=0, literal_type="int") + ) + + assert isinstance(assign.target, ArrayAccessNode) + + def test_create_compound_assignment(self): + """Test creating compound assignment.""" + assign = AssignmentNode( + target=VariableNode(name="x"), + value=LiteralNode(value=1, literal_type="int"), + operator=OperatorType.ADD_ASSIGN + ) + + assert assign.operator == OperatorType.ADD_ASSIGN + + +class TestExpressionNodes: + """Tests for expression nodes.""" + + def test_create_variable_node(self): + """Test creating variable node.""" + var = VariableNode(name="count") + + assert var.node_type == NodeType.VARIABLE + assert var.name == "count" + + def test_create_literal_int(self): + """Test creating integer literal.""" + lit = LiteralNode(value=42, literal_type="int") + + assert lit.node_type == NodeType.LITERAL + assert lit.value == 42 + assert lit.literal_type == "int" + + def test_create_literal_float(self): + """Test creating float literal.""" + lit = LiteralNode(value=3.14, literal_type="float") + + assert lit.value == 3.14 + assert lit.literal_type == "float" + + def test_create_literal_string(self): + """Test creating string literal.""" + lit = LiteralNode(value="hello", literal_type="string") + + assert lit.value == "hello" + assert lit.literal_type == "string" + + def 
test_create_literal_bool(self): + """Test creating boolean literal.""" + lit_true = LiteralNode(value=True, literal_type="bool") + lit_false = LiteralNode(value=False, literal_type="bool") + + assert lit_true.value == True + assert lit_false.value == False + + def test_create_binary_op_arithmetic(self): + """Test creating arithmetic binary operations.""" + operators = [ + OperatorType.ADD, OperatorType.SUBTRACT, + OperatorType.MULTIPLY, OperatorType.DIVIDE, + OperatorType.MODULO, OperatorType.POWER, + ] + + for op in operators: + node = BinaryOpNode( + operator=op, + left=VariableNode(name="a"), + right=VariableNode(name="b") + ) + assert node.node_type == NodeType.BINARY_OP + assert node.operator == op + + def test_create_binary_op_comparison(self): + """Test creating comparison binary operations.""" + operators = [ + OperatorType.EQUAL, OperatorType.NOT_EQUAL, + OperatorType.LESS_THAN, OperatorType.LESS_EQUAL, + OperatorType.GREATER_THAN, OperatorType.GREATER_EQUAL, + ] + + for op in operators: + node = BinaryOpNode( + operator=op, + left=VariableNode(name="a"), + right=VariableNode(name="b") + ) + assert node.operator == op + + def test_create_binary_op_logical(self): + """Test creating logical binary operations.""" + and_node = BinaryOpNode( + operator=OperatorType.AND, + left=VariableNode(name="a"), + right=VariableNode(name="b") + ) + or_node = BinaryOpNode( + operator=OperatorType.OR, + left=VariableNode(name="a"), + right=VariableNode(name="b") + ) + + assert and_node.operator == OperatorType.AND + assert or_node.operator == OperatorType.OR + + def test_create_unary_op(self): + """Test creating unary operations.""" + not_node = UnaryOpNode( + operator=OperatorType.NOT, + operand=VariableNode(name="flag") + ) + neg_node = UnaryOpNode( + operator=OperatorType.SUBTRACT, + operand=VariableNode(name="x") + ) + + assert not_node.node_type == NodeType.UNARY_OP + assert neg_node.operator == OperatorType.SUBTRACT + + def test_create_array_access_simple(self): + 
"""Test creating simple array access.""" + access = ArrayAccessNode( + array=VariableNode(name="A"), + index=VariableNode(name="i") + ) + + assert access.node_type == NodeType.ARRAY_ACCESS + assert access.array.name == "A" + + def test_create_array_access_2d(self): + """Test creating 2D array access.""" + inner = ArrayAccessNode( + array=VariableNode(name="A"), + index=VariableNode(name="i") + ) + outer = ArrayAccessNode( + array=inner, + index=VariableNode(name="j") + ) + + assert isinstance(outer.array, ArrayAccessNode) + + def test_create_complex_expression(self): + """Test creating complex nested expression.""" + # Build: (a + b) * (c - d) + add = BinaryOpNode( + operator=OperatorType.ADD, + left=VariableNode(name="a"), + right=VariableNode(name="b") + ) + sub = BinaryOpNode( + operator=OperatorType.SUBTRACT, + left=VariableNode(name="c"), + right=VariableNode(name="d") + ) + mul = BinaryOpNode( + operator=OperatorType.MULTIPLY, + left=add, + right=sub + ) + + assert mul.operator == OperatorType.MULTIPLY + assert isinstance(mul.left, BinaryOpNode) + assert isinstance(mul.right, BinaryOpNode) + + +class TestFunctionCallNode: + """Tests for FunctionCallNode.""" + + def test_create_function_call_no_args(self): + """Test creating function call with no arguments.""" + call = FunctionCallNode( + function_name="test", + arguments=[] + ) + + assert call.node_type == NodeType.FUNCTION_CALL + assert call.function_name == "test" + assert len(call.arguments) == 0 + + def test_create_function_call_with_args(self): + """Test creating function call with arguments.""" + call = FunctionCallNode( + function_name="add", + arguments=[ + VariableNode(name="a"), + VariableNode(name="b") + ] + ) + + assert len(call.arguments) == 2 + + def test_create_recursive_call(self): + """Test creating recursive function call.""" + call = RecursiveCallNode( + function_name="factorial", + arguments=[ + BinaryOpNode( + operator=OperatorType.SUBTRACT, + left=VariableNode(name="n"), + 
right=LiteralNode(value=1, literal_type="int") + ) + ], + reduction_pattern="n-1", + branching_factor=1 + ) + + assert call.node_type == NodeType.RECURSIVE_CALL + assert call.reduction_pattern == "n-1" + assert call.branching_factor == 1 + + +class TestReturnNode: + """Tests for ReturnNode.""" + + def test_create_return_with_value(self): + """Test creating return with value.""" + ret = ReturnNode(value=VariableNode(name="result")) + + assert ret.node_type == NodeType.RETURN + assert ret.value is not None + + def test_create_return_no_value(self): + """Test creating return without value.""" + ret = ReturnNode() + + assert ret.value is None + + +class TestSourceLocation: + """Tests for SourceLocation.""" + + def test_create_source_location(self): + """Test creating source location.""" + loc = SourceLocation(line=1, column=0) + + assert loc.line == 1 + assert loc.column == 0 + + def test_create_source_location_with_end(self): + """Test creating source location with end position.""" + loc = SourceLocation(line=1, column=0, end_line=5, end_column=10) + + assert loc.end_line == 5 + assert loc.end_column == 10 + + def test_source_location_str(self): + """Test source location string representation.""" + loc1 = SourceLocation(line=1, column=0) + loc2 = SourceLocation(line=1, column=0, end_line=5) + + assert "line 1" in str(loc1) + assert "lines 1-5" in str(loc2) + + +class TestASTSerialization: + """Tests for AST serialization.""" + + def test_serialize_simple_program(self): + """Test serializing simple program to dict.""" + assign = AssignmentNode( + target=VariableNode(name="x"), + value=LiteralNode(value=1, literal_type="int") + ) + block = BlockNode(statements=[assign]) + program = ProgramNode(global_statements=block) + + result = program.model_dump() + + assert isinstance(result, dict) + assert "node_type" in result + assert "global_statements" in result + + def test_serialize_function(self): + """Test serializing function to dict.""" + func = FunctionNode( + 
name="test", + parameters=[VariableNode(name="x")], + body=BlockNode(statements=[ReturnNode(value=VariableNode(name="x"))]) + ) + + result = func.model_dump() + + assert result["name"] == "test" + assert len(result["parameters"]) == 1 + + def test_serialize_loop(self): + """Test serializing loop to dict.""" + loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + start=LiteralNode(value=1, literal_type="int"), + end=VariableNode(name="n"), + estimated_iterations="n", + body=BlockNode(statements=[]) + ) + + result = loop.model_dump() + + assert result["loop_type"] == LoopType.FOR + assert result["estimated_iterations"] == "n" + + def test_serialize_complex_ast(self): + """Test serializing complex AST.""" + # Create a function with loop and conditional + condition = BinaryOpNode( + operator=OperatorType.GREATER_THAN, + left=ArrayAccessNode( + array=VariableNode(name="A"), + index=VariableNode(name="i") + ), + right=LiteralNode(value=0, literal_type="int") + ) + + if_stmt = ConditionalNode( + condition=condition, + then_branch=BlockNode(statements=[ + FunctionCallNode(function_name="print", arguments=[VariableNode(name="i")]) + ]) + ) + + loop = LoopNode( + loop_type=LoopType.FOR, + iterator=VariableNode(name="i"), + start=LiteralNode(value=1, literal_type="int"), + end=VariableNode(name="n"), + body=BlockNode(statements=[if_stmt]) + ) + + func = FunctionNode( + name="printPositive", + parameters=[VariableNode(name="A"), VariableNode(name="n")], + body=BlockNode(statements=[loop]) + ) + + program = ProgramNode(functions=[func]) + + result = program.model_dump() + + assert len(result["functions"]) == 1 + assert result["functions"][0]["name"] == "printPositive" diff --git a/evaluation_function/tests/test_complexity_schemas.py b/evaluation_function/tests/test_complexity_schemas.py new file mode 100644 index 0000000..11b328d --- /dev/null +++ b/evaluation_function/tests/test_complexity_schemas.py @@ -0,0 +1,669 @@ +""" +Comprehensive tests for the 
Complexity Schemas module. + +Tests cover: +- ComplexityClass enum operations +- Complexity comparison and equivalence +- Loop and recursion complexity analysis +- Time and space complexity structures +""" + +import pytest +from ..schemas.complexity import ( + ComplexityClass, ComplexityExpression, ComplexityFactor, + LoopComplexity, RecursionComplexity, TimeComplexity, + SpaceComplexity, ComplexityResult +) + + +class TestComplexityClass: + """Tests for ComplexityClass enum.""" + + def test_complexity_class_values(self): + """Test that complexity classes have expected values.""" + assert ComplexityClass.CONSTANT.value == "O(1)" + assert ComplexityClass.LOGARITHMIC.value == "O(log n)" + assert ComplexityClass.LINEAR.value == "O(n)" + assert ComplexityClass.LINEARITHMIC.value == "O(n log n)" + assert ComplexityClass.QUADRATIC.value == "O(n²)" + assert ComplexityClass.CUBIC.value == "O(n³)" + assert ComplexityClass.EXPONENTIAL.value == "O(2^n)" + assert ComplexityClass.FACTORIAL.value == "O(n!)" + + def test_from_string_basic(self): + """Test parsing basic complexity strings.""" + test_cases = [ + ("O(1)", ComplexityClass.CONSTANT), + ("O(n)", ComplexityClass.LINEAR), + ("O(n^2)", ComplexityClass.QUADRATIC), + ("O(n^3)", ComplexityClass.CUBIC), + ("O(log n)", ComplexityClass.LOGARITHMIC), + ("O(n log n)", ComplexityClass.LINEARITHMIC), + ("O(2^n)", ComplexityClass.EXPONENTIAL), + ("O(n!)", ComplexityClass.FACTORIAL), + ] + + for input_str, expected in test_cases: + result = ComplexityClass.from_string(input_str) + assert result == expected, f"'{input_str}' should parse to {expected}" + + def test_from_string_variations(self): + """Test parsing complexity string variations.""" + # Case variations + assert ComplexityClass.from_string("o(n)") == ComplexityClass.LINEAR + assert ComplexityClass.from_string("O(N)") == ComplexityClass.LINEAR + + # Space variations + assert ComplexityClass.from_string("O( n )") == ComplexityClass.LINEAR + assert 
ComplexityClass.from_string("O(n log n)") == ComplexityClass.LINEARITHMIC + assert ComplexityClass.from_string("O(nlogn)") == ComplexityClass.LINEARITHMIC + + # Alternative notations + assert ComplexityClass.from_string("O(n²)") == ComplexityClass.QUADRATIC + assert ComplexityClass.from_string("O(n*n)") == ComplexityClass.QUADRATIC + assert ComplexityClass.from_string("O(nn)") == ComplexityClass.QUADRATIC + + # Logarithm variations + assert ComplexityClass.from_string("O(lgn)") == ComplexityClass.LOGARITHMIC + assert ComplexityClass.from_string("O(log(n))") == ComplexityClass.LOGARITHMIC + + def test_from_string_text_names(self): + """Test parsing text-based complexity names.""" + test_cases = [ + ("constant", ComplexityClass.CONSTANT), + ("linear", ComplexityClass.LINEAR), + ("quadratic", ComplexityClass.QUADRATIC), + ("cubic", ComplexityClass.CUBIC), + ("logarithmic", ComplexityClass.LOGARITHMIC), + ("linearithmic", ComplexityClass.LINEARITHMIC), + ("exponential", ComplexityClass.EXPONENTIAL), + ("factorial", ComplexityClass.FACTORIAL), + ] + + for input_str, expected in test_cases: + result = ComplexityClass.from_string(input_str) + assert result == expected + + def test_from_string_unknown(self): + """Test parsing unknown complexity strings.""" + unknown_inputs = ["O(mystery)", "unknown", "something", "", None] + + for input_str in unknown_inputs: + if input_str is not None: + result = ComplexityClass.from_string(input_str) + assert result == ComplexityClass.UNKNOWN + + def test_compare_complexities(self): + """Test comparing complexity classes.""" + # Lower is better (more efficient) + assert ComplexityClass.compare(ComplexityClass.CONSTANT, ComplexityClass.LINEAR) == -1 + assert ComplexityClass.compare(ComplexityClass.LINEAR, ComplexityClass.QUADRATIC) == -1 + assert ComplexityClass.compare(ComplexityClass.QUADRATIC, ComplexityClass.CUBIC) == -1 + assert ComplexityClass.compare(ComplexityClass.CUBIC, ComplexityClass.EXPONENTIAL) == -1 + + # Equal + assert 
ComplexityClass.compare(ComplexityClass.LINEAR, ComplexityClass.LINEAR) == 0 + + # Higher is worse (less efficient) + assert ComplexityClass.compare(ComplexityClass.QUADRATIC, ComplexityClass.LINEAR) == 1 + + def test_compare_with_unknown(self): + """Test comparing with UNKNOWN complexity.""" + result = ComplexityClass.compare(ComplexityClass.UNKNOWN, ComplexityClass.LINEAR) + assert result == 0 # Unknown comparisons return 0 + + def test_is_equivalent(self): + """Test complexity equivalence check.""" + assert ComplexityClass.LINEAR.is_equivalent(ComplexityClass.LINEAR) + assert not ComplexityClass.LINEAR.is_equivalent(ComplexityClass.QUADRATIC) + + def test_multiply_complexities(self): + """Test multiplying complexity classes.""" + # O(n) * O(n) = O(n²) + result = ComplexityClass.multiply(ComplexityClass.LINEAR, ComplexityClass.LINEAR) + assert result == ComplexityClass.QUADRATIC + + # O(n) * O(n²) = O(n³) + result = ComplexityClass.multiply(ComplexityClass.LINEAR, ComplexityClass.QUADRATIC) + assert result == ComplexityClass.CUBIC + + # O(n) * O(log n) = O(n log n) + result = ComplexityClass.multiply(ComplexityClass.LINEAR, ComplexityClass.LOGARITHMIC) + assert result == ComplexityClass.LINEARITHMIC + + # O(1) * O(n) = O(n) + result = ComplexityClass.multiply(ComplexityClass.CONSTANT, ComplexityClass.LINEAR) + assert result == ComplexityClass.LINEAR + + # O(n) * O(1) = O(n) + result = ComplexityClass.multiply(ComplexityClass.LINEAR, ComplexityClass.CONSTANT) + assert result == ComplexityClass.LINEAR + + def test_get_order(self): + """Test getting complexity order.""" + order = ComplexityClass.get_order() + + assert order[0] == ComplexityClass.CONSTANT + assert ComplexityClass.LOGARITHMIC in order + assert ComplexityClass.LINEAR in order + assert ComplexityClass.LINEARITHMIC in order + assert ComplexityClass.QUADRATIC in order + assert ComplexityClass.EXPONENTIAL in order + assert ComplexityClass.FACTORIAL in order + + +class TestComplexityExpression: + """Tests 
for ComplexityExpression.""" + + def test_create_expression(self): + """Test creating complexity expression.""" + expr = ComplexityExpression( + base_class=ComplexityClass.LINEAR, + raw_expression="O(n)" + ) + + assert expr.base_class == ComplexityClass.LINEAR + assert expr.raw_expression == "O(n)" + + def test_expression_with_coefficient(self): + """Test expression with coefficient.""" + expr = ComplexityExpression( + base_class=ComplexityClass.LINEAR, + raw_expression="O(2n)", + coefficient=2.0 + ) + + assert expr.coefficient == 2.0 + # Asymptotically still O(n) + assert expr.to_string() == "O(n)" + + def test_expression_equivalence(self): + """Test expression equivalence.""" + expr1 = ComplexityExpression( + base_class=ComplexityClass.LINEAR, + raw_expression="O(n)", + coefficient=1.0 + ) + expr2 = ComplexityExpression( + base_class=ComplexityClass.LINEAR, + raw_expression="O(2n)", + coefficient=2.0 + ) + expr3 = ComplexityExpression( + base_class=ComplexityClass.QUADRATIC, + raw_expression="O(n^2)" + ) + + assert expr1.is_equivalent(expr2) # O(n) == O(2n) asymptotically + assert not expr1.is_equivalent(expr3) + + +class TestComplexityFactor: + """Tests for ComplexityFactor.""" + + def test_create_factor(self): + """Test creating complexity factor.""" + factor = ComplexityFactor( + source="outer loop", + factor_type="loop", + complexity=ComplexityClass.LINEAR, + iterations="n", + nesting_level=0 + ) + + assert factor.source == "outer loop" + assert factor.factor_type == "loop" + assert factor.complexity == ComplexityClass.LINEAR + + +class TestLoopComplexity: + """Tests for LoopComplexity.""" + + def test_create_simple_loop(self): + """Test creating simple loop complexity.""" + loop = LoopComplexity( + loop_type="for", + iterator_variable="i", + iterations="n", + complexity=ComplexityClass.LINEAR + ) + + assert loop.loop_type == "for" + assert loop.iterations == "n" + assert loop.get_total_complexity() == ComplexityClass.LINEAR + + def 
test_nested_loop_complexity(self): + """Test nested loop complexity calculation.""" + inner = LoopComplexity( + loop_type="for", + iterator_variable="j", + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=1 + ) + outer = LoopComplexity( + loop_type="for", + iterator_variable="i", + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=0, + nested_loops=[inner] + ) + + # O(n) * O(n) = O(n²) + total = outer.get_total_complexity() + assert total == ComplexityClass.QUADRATIC + + def test_loop_with_bounds(self): + """Test loop with explicit bounds.""" + loop = LoopComplexity( + loop_type="for", + iterator_variable="i", + iterations="n", + complexity=ComplexityClass.LINEAR, + start_bound="1", + end_bound="n", + step_size="1" + ) + + assert loop.start_bound == "1" + assert loop.end_bound == "n" + + def test_loop_to_dict(self): + """Test loop complexity serialization.""" + loop = LoopComplexity( + loop_type="for", + iterator_variable="i", + iterations="n", + complexity=ComplexityClass.LINEAR + ) + + result = loop.model_dump() + + assert result["loop_type"] == "for" + assert result["iterations"] == "n" + + +class TestRecursionComplexity: + """Tests for RecursionComplexity.""" + + def test_create_simple_recursion(self): + """Test creating simple recursion complexity.""" + rec = RecursionComplexity( + function_name="factorial", + branching_factor=1, + reduction_type="subtract", + work_per_call=ComplexityClass.CONSTANT + ) + + assert rec.function_name == "factorial" + assert rec.branching_factor == 1 + + def test_analyze_linear_recursion(self): + """Test analyzing linear recursion (factorial-like).""" + rec = RecursionComplexity( + function_name="factorial", + branching_factor=1, + reduction_factor=1, + reduction_type="subtract", + work_per_call=ComplexityClass.CONSTANT + ) + + result = rec.analyze() + assert result == ComplexityClass.LINEAR + + def test_analyze_exponential_recursion(self): + """Test analyzing exponential recursion (naive 
Fibonacci).""" + rec = RecursionComplexity( + function_name="fib", + branching_factor=2, + reduction_factor=1, + reduction_type="subtract", + work_per_call=ComplexityClass.CONSTANT + ) + + result = rec.analyze() + assert result == ComplexityClass.EXPONENTIAL + + def test_analyze_divide_conquer_logarithmic(self): + """Test analyzing divide-and-conquer with O(1) work (binary search).""" + rec = RecursionComplexity( + function_name="binarySearch", + branching_factor=1, + reduction_factor=2.0, + reduction_type="divide", + work_per_call=ComplexityClass.CONSTANT + ) + + result = rec.analyze() + assert result == ComplexityClass.LOGARITHMIC + + def test_analyze_divide_conquer_linearithmic(self): + """Test analyzing divide-and-conquer with O(n) work (merge sort).""" + rec = RecursionComplexity( + function_name="mergeSort", + branching_factor=2, + reduction_factor=2.0, + reduction_type="divide", + work_per_call=ComplexityClass.LINEAR + ) + + result = rec.analyze() + assert result == ComplexityClass.LINEARITHMIC + + def test_recurrence_pattern(self): + """Test setting recurrence pattern.""" + rec = RecursionComplexity( + function_name="mergeSort", + branching_factor=2, + reduction_factor=2.0, + reduction_type="divide", + work_per_call=ComplexityClass.LINEAR, + recurrence_pattern="T(n) = 2T(n/2) + O(n)" + ) + + assert rec.recurrence_pattern == "T(n) = 2T(n/2) + O(n)" + + def test_recursion_to_dict(self): + """Test recursion complexity serialization.""" + rec = RecursionComplexity( + function_name="factorial", + branching_factor=1, + reduction_type="subtract" + ) + + result = rec.model_dump() + + assert result["function_name"] == "factorial" + assert result["branching_factor"] == 1 + + +class TestTimeComplexity: + """Tests for TimeComplexity.""" + + def test_create_time_complexity(self): + """Test creating time complexity result.""" + tc = TimeComplexity( + overall=ComplexityClass.QUADRATIC, + expression="O(n²)" + ) + + assert tc.overall == ComplexityClass.QUADRATIC + assert 
tc.expression == "O(n²)" + + def test_time_complexity_with_contributions(self): + """Test time complexity with loop contributions.""" + loop1 = LoopComplexity( + loop_type="for", + iterator_variable="i", + iterations="n", + complexity=ComplexityClass.LINEAR + ) + loop2 = LoopComplexity( + loop_type="for", + iterator_variable="j", + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=1 + ) + + tc = TimeComplexity( + overall=ComplexityClass.QUADRATIC, + expression="O(n²)", + loop_contributions=[loop1, loop2], + dominant_factor="nested loops" + ) + + assert len(tc.loop_contributions) == 2 + assert tc.dominant_factor == "nested loops" + + def test_time_complexity_cases(self): + """Test time complexity with best/average/worst cases.""" + tc = TimeComplexity( + overall=ComplexityClass.LINEARITHMIC, + expression="O(n log n)", + best_case=ComplexityClass.LINEAR, + average_case=ComplexityClass.LINEARITHMIC, + worst_case=ComplexityClass.QUADRATIC + ) + + assert tc.best_case == ComplexityClass.LINEAR + assert tc.worst_case == ComplexityClass.QUADRATIC + + def test_time_complexity_to_dict(self): + """Test time complexity serialization.""" + tc = TimeComplexity( + overall=ComplexityClass.LINEAR, + expression="O(n)", + explanation="Single loop iterating n times" + ) + + result = tc.model_dump() + + assert result["overall"] == ComplexityClass.LINEAR + assert "explanation" in result + + +class TestSpaceComplexity: + """Tests for SpaceComplexity.""" + + def test_create_space_complexity(self): + """Test creating space complexity result.""" + sc = SpaceComplexity( + overall=ComplexityClass.CONSTANT, + expression="O(1)" + ) + + assert sc.overall == ComplexityClass.CONSTANT + + def test_space_complexity_with_auxiliary(self): + """Test space complexity with auxiliary space.""" + sc = SpaceComplexity( + overall=ComplexityClass.LINEAR, + expression="O(n)", + auxiliary_space=ComplexityClass.LINEAR, + input_space=ComplexityClass.LINEAR + ) + + assert sc.auxiliary_space == 
ComplexityClass.LINEAR + + def test_space_complexity_with_recursion_stack(self): + """Test space complexity with recursion stack.""" + sc = SpaceComplexity( + overall=ComplexityClass.LINEAR, + expression="O(n)", + recursion_stack=ComplexityClass.LINEAR + ) + + assert sc.recursion_stack == ComplexityClass.LINEAR + + def test_space_complexity_data_structures(self): + """Test space complexity with data structures.""" + sc = SpaceComplexity( + overall=ComplexityClass.LINEAR, + expression="O(n)", + data_structures=[ + {"type": "array", "size": "n"}, + {"type": "hash_table", "size": "n"} + ] + ) + + assert len(sc.data_structures) == 2 + + +class TestComplexityResult: + """Tests for ComplexityResult.""" + + def test_create_complexity_result(self): + """Test creating complete complexity result.""" + tc = TimeComplexity( + overall=ComplexityClass.QUADRATIC, + expression="O(n²)" + ) + sc = SpaceComplexity( + overall=ComplexityClass.CONSTANT, + expression="O(1)" + ) + + result = ComplexityResult( + time_complexity=tc, + space_complexity=sc + ) + + assert result.time_complexity.overall == ComplexityClass.QUADRATIC + assert result.space_complexity.overall == ComplexityClass.CONSTANT + + def test_complexity_result_with_metadata(self): + """Test complexity result with metadata.""" + tc = TimeComplexity( + overall=ComplexityClass.LINEARITHMIC, + expression="O(n log n)" + ) + sc = SpaceComplexity( + overall=ComplexityClass.LINEAR, + expression="O(n)" + ) + + result = ComplexityResult( + time_complexity=tc, + space_complexity=sc, + algorithm_type="sorting", + is_optimal=True, + confidence=0.95, + optimization_suggestions=["Consider in-place sorting to reduce space"] + ) + + assert result.algorithm_type == "sorting" + assert result.is_optimal == True + assert result.confidence == 0.95 + assert len(result.optimization_suggestions) == 1 + + def test_complexity_result_with_warnings(self): + """Test complexity result with warnings.""" + tc = TimeComplexity( + 
overall=ComplexityClass.UNKNOWN, + expression="unknown" + ) + sc = SpaceComplexity( + overall=ComplexityClass.UNKNOWN, + expression="unknown" + ) + + result = ComplexityResult( + time_complexity=tc, + space_complexity=sc, + confidence=0.3, + warnings=["Could not determine loop bounds", "Recursion pattern unclear"] + ) + + assert len(result.warnings) == 2 + assert result.confidence == 0.3 + + def test_complexity_result_to_dict(self): + """Test complexity result serialization.""" + tc = TimeComplexity( + overall=ComplexityClass.LINEAR, + expression="O(n)" + ) + sc = SpaceComplexity( + overall=ComplexityClass.CONSTANT, + expression="O(1)" + ) + + result = ComplexityResult( + time_complexity=tc, + space_complexity=sc + ) + + data = result.model_dump() + + assert "time_complexity" in data + assert "space_complexity" in data + assert data["time_complexity"]["overall"] == ComplexityClass.LINEAR + + +class TestComplexityClassOrdering: + """Tests for complexity class ordering and comparisons.""" + + def test_complexity_ordering(self): + """Test that complexities are correctly ordered.""" + order = ComplexityClass.get_order() + + for i in range(len(order) - 1): + result = ComplexityClass.compare(order[i], order[i + 1]) + assert result == -1, f"{order[i]} should be less than {order[i + 1]}" + + def test_all_pairs_comparison(self): + """Test comparing all pairs of complexity classes.""" + order = ComplexityClass.get_order() + + for i, c1 in enumerate(order): + for j, c2 in enumerate(order): + result = ComplexityClass.compare(c1, c2) + if i < j: + assert result == -1 + elif i > j: + assert result == 1 + else: + assert result == 0 + + def test_symmetric_comparison(self): + """Test that comparisons are symmetric.""" + c1, c2 = ComplexityClass.LINEAR, ComplexityClass.QUADRATIC + + assert ComplexityClass.compare(c1, c2) == -ComplexityClass.compare(c2, c1) + + +class TestEdgeCases: + """Tests for edge cases in complexity schemas.""" + + def test_empty_expression(self): + """Test 
handling of empty complexity expression.""" + result = ComplexityClass.from_string("") + assert result == ComplexityClass.UNKNOWN + + def test_nested_loops_deep(self): + """Test deeply nested loop complexity.""" + # Create 5 nested loops + loops = [] + for i in range(5): + loop = LoopComplexity( + loop_type="for", + iterator_variable=f"i{i}", + iterations="n", + complexity=ComplexityClass.LINEAR, + nesting_level=i + ) + loops.append(loop) + + # Link them + for i in range(len(loops) - 1): + loops[i].nested_loops = [loops[i + 1]] + + # O(n^5) should be at least CUBIC or POLYNOMIAL + total = loops[0].get_total_complexity() + # Accept CUBIC or POLYNOMIAL since deep nesting multiplication is complex + assert total in [ComplexityClass.CUBIC, ComplexityClass.POLYNOMIAL], f"Expected CUBIC or POLYNOMIAL, got {total}" + + def test_zero_branching_factor(self): + """Test recursion with zero branching factor.""" + rec = RecursionComplexity( + function_name="test", + branching_factor=0, + reduction_type="subtract" + ) + + # Should not crash + result = rec.analyze() + assert result is not None + + def test_large_coefficient(self): + """Test expression with large coefficient.""" + expr = ComplexityExpression( + base_class=ComplexityClass.LINEAR, + raw_expression="O(1000000n)", + coefficient=1000000 + ) + + # Still O(n) asymptotically + assert expr.to_string() == "O(n)" diff --git a/evaluation_function/tests/test_evaluation.py b/evaluation_function/tests/test_evaluation.py new file mode 100644 index 0000000..d23e8aa --- /dev/null +++ b/evaluation_function/tests/test_evaluation.py @@ -0,0 +1,325 @@ +""" +Tests for the evaluation function. 
+ +Tests cover: +- Basic evaluation functionality +- Complexity bound checking (code should be <= bound) +- Different complexity classes +- Partial credit scoring +- Feedback generation +- Error handling +""" + +import pytest + + +class MockParams: + """Mock params object for testing.""" + def __init__(self, **kwargs): + self._data = kwargs + + def __iter__(self): + return iter(self._data) + + def __getitem__(self, key): + return self._data[key] + + def to_dict(self): + return self._data + + +@pytest.fixture +def params(): + """Default params fixture.""" + return MockParams() + + +class TestEvaluationBasic: + """Basic evaluation function tests.""" + + def test_evaluation_returns_result(self, params): + """Test that evaluation returns a Result object.""" + from ..evaluation import evaluation_function + + response = "FOR i = 1 TO n DO\n x = x + 1\nEND FOR" + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert hasattr(result, 'is_correct') + assert hasattr(result, 'to_dict') + + def test_linear_meets_linear_bound(self, params): + """Test linear code meets O(n) bound.""" + from ..evaluation import evaluation_function + + response = "FOR i = 1 TO n DO\n x = x + 1\nEND FOR" + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + def test_constant_meets_linear_bound(self, params): + """Test constant code meets O(n) bound (better than required).""" + from ..evaluation import evaluation_function + + response = "x = 1\ny = 2" + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + def test_quadratic_exceeds_linear_bound(self, params): + """Test quadratic code exceeds O(n) bound.""" + from ..evaluation import evaluation_function + + response = """FOR i = 1 TO n DO + FOR j = 1 TO n DO + x = x + 1 + END FOR +END FOR""" + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is 
False + + def test_quadratic_meets_quadratic_bound(self, params): + """Test quadratic code meets O(n^2) bound.""" + from ..evaluation import evaluation_function + + response = """FOR i = 1 TO n DO + FOR j = 1 TO n DO + x = x + 1 + END FOR +END FOR""" + answer = "O(n^2)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + +class TestEvaluationComplexityBounds: + """Test various complexity bounds.""" + + def test_log_n_meets_log_n_bound(self, params): + """Test O(log n) code meets O(log n) bound.""" + from ..evaluation import evaluation_function + + response = """FUNCTION binarySearch(A, target, low, high) + IF low > high THEN + RETURN -1 + END IF + mid = (low + high) / 2 + IF A[mid] == target THEN + RETURN mid + ELSE IF A[mid] < target THEN + RETURN binarySearch(A, target, mid + 1, high) + ELSE + RETURN binarySearch(A, target, low, mid - 1) + END IF +END FUNCTION""" + answer = "O(log n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + def test_linear_exceeds_log_n_bound(self, params): + """Test O(n) code exceeds O(log n) bound.""" + from ..evaluation import evaluation_function + + response = "FOR i = 1 TO n DO\n x = x + 1\nEND FOR" + answer = "O(log n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is False + + def test_nlogn_meets_nlogn_bound(self, params): + """Test O(n log n) code meets O(n log n) bound.""" + from ..evaluation import evaluation_function + + response = """FUNCTION mergeSort(A, low, high) + IF low < high THEN + mid = (low + high) / 2 + mergeSort(A, low, mid) + mergeSort(A, mid + 1, high) + merge(A, low, mid, high) + END IF +END FUNCTION""" + answer = "O(n log n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + def test_linear_meets_nlogn_bound(self, params): + """Test O(n) code meets O(n log n) bound (better than required).""" + from ..evaluation import 
evaluation_function + + response = "FOR i = 1 TO n DO\n x = x + 1\nEND FOR" + answer = "O(n log n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + def test_cubic_meets_cubic_bound(self, params): + """Test O(n^3) code meets O(n^3) bound.""" + from ..evaluation import evaluation_function + + response = """FOR i = 1 TO n DO + FOR j = 1 TO n DO + FOR k = 1 TO n DO + x = x + 1 + END FOR + END FOR +END FOR""" + answer = "O(n^3)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + +class TestEvaluationDictFormats: + """Test evaluation with dict response/answer formats.""" + + def test_dict_answer_time_complexity(self, params): + """Test dict answer with expected_time_complexity.""" + from ..evaluation import evaluation_function + + response = "FOR i = 1 TO n DO\n x = x + 1\nEND FOR" + answer = {"expected_time_complexity": "O(n)"} + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + def test_dict_response_with_pseudocode(self, params): + """Test dict response with pseudocode key.""" + from ..evaluation import evaluation_function + + response = {"pseudocode": "FOR i = 1 TO n DO\n x = x + 1\nEND FOR"} + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + +class TestEvaluationFeedback: + """Test feedback generation.""" + + def test_feedback_present_in_result(self, params): + """Test feedback is present in result.""" + from ..evaluation import evaluation_function + + response = "FOR i = 1 TO n DO\n x = x + 1\nEND FOR" + answer = "O(n)" + result = evaluation_function(response, answer, params) + + result_dict = result.to_dict() + assert "feedback" in result_dict + assert len(result_dict["feedback"]) > 0 + + def test_feedback_shows_complexity(self, params): + """Test feedback shows detected complexity.""" + from ..evaluation import evaluation_function + + response = "FOR i = 1 TO 
n DO\n x = x + 1\nEND FOR" + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert "O(n)" in result.feedback + + def test_correct_feedback_positive(self, params): + """Test correct answer gets positive feedback.""" + from ..evaluation import evaluation_function + + response = "FOR i = 1 TO n DO\n x = x + 1\nEND FOR" + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert "Correct" in result.feedback or "meets" in result.feedback + + +class TestEvaluationErrorHandling: + """Test error handling.""" + + def test_empty_pseudocode(self, params): + """Test handling of empty pseudocode.""" + from ..evaluation import evaluation_function + + response = "" + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is False + assert "No pseudocode" in result.feedback + + def test_none_response(self, params): + """Test handling of None response.""" + from ..evaluation import evaluation_function + + response = None + answer = "O(n)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is False + + +class TestEvaluationComplexityVariants: + """Test different complexity notation variants.""" + + def test_accepts_n_squared_notation(self, params): + """Test accepts O(n^2) notation.""" + from ..evaluation import evaluation_function + + response = """FOR i = 1 TO n DO + FOR j = 1 TO n DO + x = x + 1 + END FOR +END FOR""" + answer = "O(n^2)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + def test_accepts_unicode_squared(self, params): + """Test accepts O(n²) unicode notation.""" + from ..evaluation import evaluation_function + + response = """FOR i = 1 TO n DO + FOR j = 1 TO n DO + x = x + 1 + END FOR +END FOR""" + answer = "O(n²)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + def test_accepts_quadratic_word(self, params): + """Test accepts 
'quadratic' as answer.""" + from ..evaluation import evaluation_function + + response = """FOR i = 1 TO n DO + FOR j = 1 TO n DO + x = x + 1 + END FOR +END FOR""" + answer = "quadratic" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True + + +class TestEvaluationCurlyBraceSyntax: + """Test evaluation with curly brace syntax.""" + + def test_curly_brace_loops(self, params): + """Test curly brace loop syntax.""" + from ..evaluation import evaluation_function + + response = """FOR i = 1 TO n { + FOR j = 1 TO n { + x = x + 1 + } +}""" + answer = "O(n^2)" + result = evaluation_function(response, answer, params) + + assert result.is_correct is True diff --git a/evaluation_function/tests/test_input_output_schemas.py b/evaluation_function/tests/test_input_output_schemas.py new file mode 100644 index 0000000..7ed7f45 --- /dev/null +++ b/evaluation_function/tests/test_input_output_schemas.py @@ -0,0 +1,639 @@ +""" +Comprehensive tests for Input and Output Schemas. 
+ +Tests cover: +- StudentResponse validation +- ExpectedAnswer validation +- EvaluationParams configuration +- EvaluationResult structure +- Feedback items +- Parse results +""" + +import pytest +from pydantic import ValidationError + +from ..schemas.input_schema import ( + StudentResponse, ExpectedAnswer, EvaluationParams +) +from ..schemas.output_schema import ( + EvaluationResult, TimeComplexityResult, SpaceComplexityResult, + ComplexityAnalysis, ConstructAnalysis, FeedbackItem, FeedbackLevel, + ParseResult +) +from ..schemas.complexity import ComplexityClass, TimeComplexity, SpaceComplexity + + +class TestStudentResponse: + """Tests for StudentResponse schema.""" + + def test_create_minimal_response(self): + """Test creating response with only required field.""" + response = StudentResponse(pseudocode="x = 1") + + assert response.pseudocode == "x = 1" + assert response.time_complexity is None + assert response.space_complexity is None + + def test_create_full_response(self): + """Test creating response with all fields.""" + response = StudentResponse( + pseudocode="FOR i = 1 TO n DO\n print(i)\nEND FOR", + time_complexity="O(n)", + space_complexity="O(1)", + explanation="Single loop iterates n times" + ) + + assert response.pseudocode is not None + assert response.time_complexity == "O(n)" + assert response.space_complexity == "O(1)" + assert response.explanation is not None + + def test_pseudocode_validation_empty(self): + """Test that empty pseudocode raises validation error.""" + with pytest.raises(ValidationError): + StudentResponse(pseudocode="") + + def test_pseudocode_validation_whitespace_only(self): + """Test that whitespace-only pseudocode raises validation error.""" + with pytest.raises(ValidationError): + StudentResponse(pseudocode=" \n\t ") + + def test_pseudocode_stripped(self): + """Test that pseudocode is stripped of leading/trailing whitespace.""" + response = StudentResponse(pseudocode=" x = 1 ") + assert response.pseudocode == "x = 1" + + 
def test_to_dict(self): + """Test response serialization to dict.""" + response = StudentResponse( + pseudocode="x = 1", + time_complexity="O(1)" + ) + + result = response.model_dump() + + assert result["pseudocode"] == "x = 1" + assert result["time_complexity"] == "O(1)" + + def test_json_schema_example(self): + """Test that JSON schema example is valid.""" + schema = StudentResponse.model_json_schema() + + assert "example" in schema or "properties" in schema + + +class TestExpectedAnswer: + """Tests for ExpectedAnswer schema.""" + + def test_create_minimal_answer(self): + """Test creating answer with only required field.""" + answer = ExpectedAnswer(expected_time_complexity="O(n)") + + assert answer.expected_time_complexity == "O(n)" + assert answer.expected_space_complexity == "O(1)" # Default + + def test_create_full_answer(self): + """Test creating answer with all fields.""" + answer = ExpectedAnswer( + expected_time_complexity="O(n^2)", + expected_space_complexity="O(1)", + acceptable_time_alternatives=["O(n*n)", "O(n²)"], + acceptable_space_alternatives=["O(1)", "constant"], + algorithm_description="Bubble sort implementation", + algorithm_type="sorting", + expected_constructs=["nested_loop"], + time_complexity_weight=0.6, + space_complexity_weight=0.4 + ) + + assert answer.expected_time_complexity == "O(n^2)" + assert len(answer.acceptable_time_alternatives) == 2 + assert answer.algorithm_type == "sorting" + + def test_get_all_acceptable_time(self): + """Test getting all acceptable time complexities.""" + answer = ExpectedAnswer( + expected_time_complexity="O(n)", + acceptable_time_alternatives=["O(1*n)", "linear"] + ) + + all_acceptable = answer.get_all_acceptable_time() + + assert "O(n)" in all_acceptable + assert "O(1*n)" in all_acceptable + assert "linear" in all_acceptable + assert len(all_acceptable) == 3 + + def test_get_all_acceptable_space(self): + """Test getting all acceptable space complexities.""" + answer = ExpectedAnswer( + 
expected_time_complexity="O(n)", + expected_space_complexity="O(1)", + acceptable_space_alternatives=["constant"] + ) + + all_acceptable = answer.get_all_acceptable_space() + + assert "O(1)" in all_acceptable + assert "constant" in all_acceptable + + def test_weight_validation(self): + """Test that weights are validated to be between 0 and 1.""" + # Valid weights + answer = ExpectedAnswer( + expected_time_complexity="O(n)", + time_complexity_weight=0.7, + space_complexity_weight=0.3 + ) + assert answer.time_complexity_weight == 0.7 + + # Invalid weights should raise error + with pytest.raises(ValidationError): + ExpectedAnswer( + expected_time_complexity="O(n)", + time_complexity_weight=1.5 # > 1 + ) + + with pytest.raises(ValidationError): + ExpectedAnswer( + expected_time_complexity="O(n)", + time_complexity_weight=-0.1 # < 0 + ) + + +class TestEvaluationParams: + """Tests for EvaluationParams schema.""" + + def test_default_params(self): + """Test default parameter values.""" + params = EvaluationParams() + + assert params.analyze_pseudocode == True + assert params.require_time_complexity == True + assert params.require_space_complexity == True + assert params.partial_credit == True + assert params.complexity_equivalence == True + assert params.show_detailed_feedback == True + + def test_custom_params(self): + """Test creating custom parameters.""" + params = EvaluationParams( + analyze_pseudocode=False, + require_time_complexity=True, + require_space_complexity=False, + partial_credit=False, + time_weight=1.0, + space_weight=0.0 + ) + + assert params.analyze_pseudocode == False + assert params.time_weight == 1.0 + + def test_weight_validation(self): + """Test that weights are validated.""" + # Valid + params = EvaluationParams(time_weight=0.5, space_weight=0.5) + assert params.time_weight == 0.5 + + # Invalid + with pytest.raises(ValidationError): + EvaluationParams(time_weight=2.0) + + def test_max_nesting_depth_validation(self): + """Test max nesting depth 
validation.""" + # Valid + params = EvaluationParams(max_nesting_depth=20) + assert params.max_nesting_depth == 20 + + # Too low + with pytest.raises(ValidationError): + EvaluationParams(max_nesting_depth=0) + + # Too high + with pytest.raises(ValidationError): + EvaluationParams(max_nesting_depth=100) + + def test_timeout_validation(self): + """Test timeout validation.""" + # Valid + params = EvaluationParams(timeout_seconds=10.0) + assert params.timeout_seconds == 10.0 + + # Too low + with pytest.raises(ValidationError): + EvaluationParams(timeout_seconds=0.01) + + # Too high + with pytest.raises(ValidationError): + EvaluationParams(timeout_seconds=120.0) + + def test_pseudocode_style_options(self): + """Test pseudocode style options.""" + styles = ["auto", "python", "pascal", "c"] + + for style in styles: + params = EvaluationParams(pseudocode_style=style) + assert params.pseudocode_style == style + + +class TestFeedbackItem: + """Tests for FeedbackItem schema.""" + + def test_create_feedback_item(self): + """Test creating feedback item.""" + item = FeedbackItem( + level=FeedbackLevel.WARNING, + message="Time complexity is incorrect" + ) + + assert item.level == FeedbackLevel.WARNING + assert item.message == "Time complexity is incorrect" + + def test_feedback_levels(self): + """Test all feedback levels.""" + levels = [ + FeedbackLevel.INFO, + FeedbackLevel.SUCCESS, + FeedbackLevel.WARNING, + FeedbackLevel.ERROR, + FeedbackLevel.HINT + ] + + for level in levels: + item = FeedbackItem(level=level, message="Test") + assert item.level == level + + def test_feedback_with_details(self): + """Test feedback item with all details.""" + item = FeedbackItem( + level=FeedbackLevel.ERROR, + message="Expected O(n²) but got O(n)", + category="time_complexity", + location="line 3", + suggestion="Consider the nested loop structure" + ) + + assert item.category == "time_complexity" + assert item.location == "line 3" + assert item.suggestion is not None + + +class 
TestConstructAnalysis: + """Tests for ConstructAnalysis schema.""" + + def test_create_construct_analysis(self): + """Test creating construct analysis.""" + analysis = ConstructAnalysis( + construct_type="nested_loop", + description="Two nested FOR loops", + complexity_contribution=ComplexityClass.QUADRATIC + ) + + assert analysis.construct_type == "nested_loop" + assert analysis.complexity_contribution == ComplexityClass.QUADRATIC + + def test_construct_with_details(self): + """Test construct analysis with details.""" + analysis = ConstructAnalysis( + construct_type="loop", + description="FOR loop from 1 to n", + location="lines 1-3", + complexity_contribution=ComplexityClass.LINEAR, + details={ + "iterator": "i", + "start": 1, + "end": "n", + "step": 1 + } + ) + + assert analysis.details["iterator"] == "i" + + +class TestTimeComplexityResult: + """Tests for TimeComplexityResult schema.""" + + def test_create_correct_result(self): + """Test creating correct time complexity result.""" + result = TimeComplexityResult( + is_correct=True, + student_answer="O(n^2)", + expected_answer="O(n²)", + expected_normalized=ComplexityClass.QUADRATIC, + student_normalized=ComplexityClass.QUADRATIC, + feedback="Correct!" 
+ ) + + assert result.is_correct == True + assert result.student_answer == "O(n^2)" + + def test_create_incorrect_result(self): + """Test creating incorrect time complexity result.""" + result = TimeComplexityResult( + is_correct=False, + student_answer="O(n)", + expected_answer="O(n²)", + expected_normalized=ComplexityClass.QUADRATIC, + student_normalized=ComplexityClass.LINEAR, + detected_complexity="O(n²)", + feedback="Your answer O(n) differs from expected O(n²)" + ) + + assert result.is_correct == False + assert result.detected_complexity == "O(n²)" + + +class TestSpaceComplexityResult: + """Tests for SpaceComplexityResult schema.""" + + def test_create_space_result(self): + """Test creating space complexity result.""" + result = SpaceComplexityResult( + is_correct=True, + student_answer="O(1)", + expected_answer="O(1)", + expected_normalized=ComplexityClass.CONSTANT, + student_normalized=ComplexityClass.CONSTANT + ) + + assert result.is_correct == True + + +class TestParseResult: + """Tests for ParseResult schema.""" + + def test_successful_parse(self): + """Test successful parse result.""" + result = ParseResult( + success=True, + errors=[], + warnings=[], + normalized_code="for i = 1 to n do\n print(i)" + ) + + assert result.success == True + assert len(result.errors) == 0 + + def test_failed_parse(self): + """Test failed parse result.""" + result = ParseResult( + success=False, + errors=["Syntax error at line 3", "Unexpected token"], + warnings=["Inconsistent indentation"] + ) + + assert result.success == False + assert len(result.errors) == 2 + assert len(result.warnings) == 1 + + +class TestComplexityAnalysis: + """Tests for ComplexityAnalysis schema.""" + + def test_create_analysis(self): + """Test creating complexity analysis.""" + tc = TimeComplexity( + overall=ComplexityClass.QUADRATIC, + expression="O(n²)" + ) + sc = SpaceComplexity( + overall=ComplexityClass.CONSTANT, + expression="O(1)" + ) + + analysis = ComplexityAnalysis( + time_complexity=tc, 
+ space_complexity=sc, + constructs=[], + confidence=0.9 + ) + + assert analysis.confidence == 0.9 + + def test_analysis_with_constructs(self): + """Test analysis with detected constructs.""" + tc = TimeComplexity( + overall=ComplexityClass.QUADRATIC, + expression="O(n²)" + ) + sc = SpaceComplexity( + overall=ComplexityClass.CONSTANT, + expression="O(1)" + ) + + constructs = [ + ConstructAnalysis( + construct_type="nested_loop", + complexity_contribution=ComplexityClass.QUADRATIC + ) + ] + + analysis = ComplexityAnalysis( + time_complexity=tc, + space_complexity=sc, + constructs=constructs, + algorithm_type="iteration" + ) + + assert len(analysis.constructs) == 1 + assert analysis.algorithm_type == "iteration" + + +class TestEvaluationResult: + """Tests for EvaluationResult schema.""" + + def test_create_correct_result(self): + """Test creating correct evaluation result.""" + result = EvaluationResult( + is_correct=True, + score=1.0, + feedback="Excellent! All answers correct." + ) + + assert result.is_correct == True + assert result.score == 1.0 + + def test_create_partial_result(self): + """Test creating partial credit result.""" + time_result = TimeComplexityResult( + is_correct=True, + student_answer="O(n²)", + expected_answer="O(n²)", + expected_normalized=ComplexityClass.QUADRATIC + ) + space_result = SpaceComplexityResult( + is_correct=False, + student_answer="O(n)", + expected_answer="O(1)", + expected_normalized=ComplexityClass.CONSTANT + ) + + result = EvaluationResult( + is_correct=False, + time_complexity_result=time_result, + space_complexity_result=space_result, + score=0.5, + feedback="Time complexity correct, but space complexity is incorrect." 
+ ) + + assert result.score == 0.5 + assert result.time_complexity_result.is_correct == True + assert result.space_complexity_result.is_correct == False + + def test_create_result_with_feedback_items(self): + """Test creating result with feedback items.""" + feedback_items = [ + FeedbackItem(level=FeedbackLevel.SUCCESS, message="Time complexity correct"), + FeedbackItem(level=FeedbackLevel.ERROR, message="Space complexity incorrect"), + FeedbackItem(level=FeedbackLevel.HINT, message="Consider the auxiliary array") + ] + + result = EvaluationResult( + is_correct=False, + score=0.5, + feedback="Partial credit awarded", + feedback_items=feedback_items + ) + + assert len(result.feedback_items) == 3 + + def test_create_result_with_warnings_errors(self): + """Test creating result with warnings and errors.""" + result = EvaluationResult( + is_correct=True, + score=1.0, + warnings=["Could not fully parse line 5"], + errors=[] + ) + + assert len(result.warnings) == 1 + assert len(result.errors) == 0 + + def test_to_lambda_feedback_response(self): + """Test conversion to Lambda Feedback response format.""" + time_result = TimeComplexityResult( + is_correct=True, + student_answer="O(n)", + expected_answer="O(n)", + expected_normalized=ComplexityClass.LINEAR, + feedback="Correct!" + ) + + result = EvaluationResult( + is_correct=True, + time_complexity_result=time_result, + score=1.0, + feedback="All correct!" + ) + + response = result.to_lambda_feedback_response() + + assert response["is_correct"] == True + assert response["feedback"] == "All correct!" 
+ assert "time_complexity" in response + + def test_score_validation(self): + """Test that score is validated to be between 0 and 1.""" + # Valid + result = EvaluationResult(is_correct=True, score=0.5) + assert result.score == 0.5 + + # Invalid - too high + with pytest.raises(ValidationError): + EvaluationResult(is_correct=True, score=1.5) + + # Invalid - too low + with pytest.raises(ValidationError): + EvaluationResult(is_correct=True, score=-0.1) + + def test_result_with_analysis(self): + """Test result with complexity analysis.""" + tc = TimeComplexity( + overall=ComplexityClass.LINEAR, + expression="O(n)" + ) + sc = SpaceComplexity( + overall=ComplexityClass.CONSTANT, + expression="O(1)" + ) + + analysis = ComplexityAnalysis( + time_complexity=tc, + space_complexity=sc, + constructs=[] + ) + + result = EvaluationResult( + is_correct=True, + score=1.0, + analysis=analysis + ) + + assert result.analysis is not None + assert result.analysis.time_complexity.overall == ComplexityClass.LINEAR + + def test_result_with_metadata(self): + """Test result with custom metadata.""" + result = EvaluationResult( + is_correct=True, + score=1.0, + metadata={ + "parse_time_ms": 15, + "analysis_time_ms": 25, + "total_lines": 10 + } + ) + + assert result.metadata["parse_time_ms"] == 15 + + +class TestSchemaRoundTrip: + """Tests for schema serialization/deserialization round trips.""" + + def test_student_response_roundtrip(self): + """Test StudentResponse serialization round trip.""" + original = StudentResponse( + pseudocode="FOR i = 1 TO n DO\n x = x + 1\nEND FOR", + time_complexity="O(n)", + space_complexity="O(1)" + ) + + # Serialize and deserialize + data = original.model_dump() + restored = StudentResponse(**data) + + assert restored.pseudocode == original.pseudocode + assert restored.time_complexity == original.time_complexity + + def test_expected_answer_roundtrip(self): + """Test ExpectedAnswer serialization round trip.""" + original = ExpectedAnswer( + 
expected_time_complexity="O(n²)", + expected_space_complexity="O(1)", + acceptable_time_alternatives=["O(n^2)"] + ) + + data = original.model_dump() + restored = ExpectedAnswer(**data) + + assert restored.expected_time_complexity == original.expected_time_complexity + assert len(restored.acceptable_time_alternatives) == 1 + + def test_evaluation_result_roundtrip(self): + """Test EvaluationResult serialization round trip.""" + original = EvaluationResult( + is_correct=True, + score=0.75, + feedback="Good work!", + feedback_items=[ + FeedbackItem(level=FeedbackLevel.SUCCESS, message="Correct") + ] + ) + + data = original.model_dump() + restored = EvaluationResult(**data) + + assert restored.is_correct == original.is_correct + assert restored.score == original.score + assert len(restored.feedback_items) == 1 diff --git a/evaluation_function/tests/test_integration.py b/evaluation_function/tests/test_integration.py new file mode 100644 index 0000000..0bafc7a --- /dev/null +++ b/evaluation_function/tests/test_integration.py @@ -0,0 +1,620 @@ +""" +Integration tests for the complete evaluation pipeline. 
+ +Tests cover: +- End-to-end parsing and analysis +- Various algorithm complexities +- Different pseudocode styles +- Error handling scenarios +- Edge cases +""" + +import pytest +from ..parser.parser import PseudocodeParser +from ..parser.preprocessor import Preprocessor +from ..schemas.complexity import ComplexityClass +from ..schemas.ast_nodes import NodeType, LoopType + + +class TestEndToEndParsing: + """End-to-end tests for parsing pipeline.""" + + def test_parse_simple_assignment(self, parser): + """Test parsing and analyzing simple assignment.""" + code = "x = 1" + result = parser.parse(code) + + assert result is not None + assert result.normalized_code is not None + + def test_parse_simple_loop(self, parser): + """Test parsing and analyzing simple loop.""" + code = """FOR i = 1 TO n DO + print(i) +END FOR""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + assert structure['loop_count'] >= 1 + + def test_parse_nested_loops(self, parser): + """Test parsing and analyzing nested loops.""" + code = """FOR i = 1 TO n DO + FOR j = 1 TO n DO + sum = sum + A[i][j] + END FOR +END FOR""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + assert structure['has_nested_loops'] + + def test_parse_function_with_loop(self, parser): + """Test parsing function containing loop.""" + code = """FUNCTION sum(A, n) + total = 0 + FOR i = 1 TO n DO + total = total + A[i] + END FOR + RETURN total +END FUNCTION""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + +class TestAlgorithmComplexities: + """Tests for various algorithm complexities.""" + + def test_constant_complexity(self, parser): + """Test O(1) constant complexity detection.""" + code = """x = 1 +y = 2 +z = x + y +RETURN z""" + + structure = parser.detect_structure(code) + assert not structure['has_loops'] + assert not 
structure['has_recursion'] + + def test_linear_complexity(self, parser, linear_search): + """Test O(n) linear complexity detection.""" + structure = parser.detect_structure(linear_search) + + assert structure['has_loops'] + assert structure['loop_count'] == 1 + assert not structure['has_nested_loops'] + + def test_quadratic_complexity(self, parser, bubble_sort): + """Test O(n²) quadratic complexity detection.""" + structure = parser.detect_structure(bubble_sort) + + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 2 + + def test_cubic_complexity(self, parser, matrix_multiplication): + """Test O(n³) cubic complexity detection.""" + structure = parser.detect_structure(matrix_multiplication) + + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 3 + + def test_logarithmic_complexity(self, parser, binary_search): + """Test O(log n) logarithmic complexity detection.""" + structure = parser.detect_structure(binary_search) + + assert structure['has_loops'] + assert structure['has_conditionals'] + + def test_linearithmic_complexity(self, parser, merge_sort): + """Test O(n log n) linearithmic complexity detection.""" + structure = parser.detect_structure(merge_sort) + + assert structure['has_recursion'] + + def test_exponential_complexity(self, parser, recursive_fibonacci): + """Test O(2^n) exponential complexity detection.""" + structure = parser.detect_structure(recursive_fibonacci) + + assert structure['has_recursion'] + + +class TestPseudocodeStyles: + """Tests for different pseudocode style variations.""" + + def test_pascal_style(self, parser, pascal_style_loop): + """Test Pascal-style pseudocode.""" + result = parser.parse(pascal_style_loop) + + structure = parser.detect_structure(pascal_style_loop) + assert structure['has_loops'] + + def test_python_style(self, parser, python_style_loop): + """Test Python-style pseudocode.""" + result = 
parser.parse(python_style_loop) + + structure = parser.detect_structure(python_style_loop) + assert structure['has_loops'] + + def test_mixed_case_keywords(self, parser, mixed_case_keywords): + """Test mixed case keywords handling.""" + result = parser.parse(mixed_case_keywords) + + # Preprocessing should normalize case + assert result.normalized_code is not None + structure = parser.detect_structure(mixed_case_keywords) + assert structure['has_loops'] + + def test_unicode_operators(self, parser, unicode_operators): + """Test unicode operator handling.""" + result = parser.parse(unicode_operators) + + # Preprocessing should normalize operators + assert "←" not in result.normalized_code + assert "≤" not in result.normalized_code + + +class TestLoopVariations: + """Tests for various loop construct variations.""" + + def test_for_loop_variations(self, parser): + """Test different FOR loop syntaxes.""" + variations = [ + "FOR i = 1 TO n DO\n x = x + 1\nEND FOR", + "for i = 1 to n do\n x = x + 1\nend for", + "FOR i := 1 TO n DO\n x = x + 1\nEND FOR", + "FOR i ← 1 TO n DO\n x = x + 1\nEND FOR", + ] + + for code in variations: + result = parser.parse(code) + structure = parser.detect_structure(code) + assert structure['has_loops'], f"Failed to detect loop in: {code[:30]}..." 
+ + def test_while_loop_variations(self, parser): + """Test different WHILE loop syntaxes.""" + variations = [ + "WHILE i < n DO\n i = i + 1\nEND WHILE", + "while i < n do\n i = i + 1\nend while", + "WHILE (i < n)\n i = i + 1\nEND WHILE", + ] + + for code in variations: + result = parser.parse(code) + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_foreach_variations(self, parser): + """Test different FOR-EACH loop syntaxes.""" + variations = [ + "FOR EACH item IN list DO\n print(item)\nEND FOR", + "FOR item IN list DO\n print(item)\nEND FOR", + "for each x in array do\n process(x)\nend for", + ] + + for code in variations: + result = parser.parse(code) + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_repeat_until_variations(self, parser): + """Test REPEAT-UNTIL loop syntaxes.""" + code = """REPEAT + x = x + 1 +UNTIL x >= n""" + + result = parser.parse(code) + structure = parser.detect_structure(code) + assert structure['has_loops'] + + +class TestConditionalVariations: + """Tests for various conditional construct variations.""" + + def test_if_variations(self, parser): + """Test different IF statement syntaxes.""" + variations = [ + "IF x > 0 THEN\n y = 1\nEND IF", + "if x > 0 then\n y = 1\nend if", + "IF x > 0:\n y = 1\nEND IF", + "IF (x > 0) THEN\n y = 1\nENDIF", + ] + + for code in variations: + result = parser.parse(code) + structure = parser.detect_structure(code) + assert structure['has_conditionals'] + + def test_if_else_variations(self, parser): + """Test IF-ELSE variations.""" + code = """IF x > 0 THEN + y = 1 +ELSE + y = -1 +END IF""" + + result = parser.parse(code) + structure = parser.detect_structure(code) + assert structure['has_conditionals'] + + def test_nested_conditionals(self, parser): + """Test nested conditionals.""" + code = """IF x > 0 THEN + IF y > 0 THEN + z = 1 + ELSE + z = 2 + END IF +ELSE + z = 3 +END IF""" + + result = parser.parse(code) + structure = 
parser.detect_structure(code) + assert structure['has_conditionals'] + + +class TestFunctionVariations: + """Tests for various function definition variations.""" + + def test_function_keywords(self, parser): + """Test different function definition keywords.""" + keywords = ["FUNCTION", "function", "ALGORITHM", "algorithm", + "PROCEDURE", "procedure", "DEF", "def"] + + for kw in keywords: + code = f"""{kw} test(x) + RETURN x * 2 +END {kw.upper() if kw.isupper() else 'FUNCTION'}""" + + result = parser.parse(code) + assert result is not None + + def test_function_with_parameters(self, parser): + """Test functions with various parameter styles.""" + variations = [ + "FUNCTION test(a, b)\n RETURN a + b\nEND FUNCTION", + "FUNCTION test(A[1..n])\n RETURN A[1]\nEND FUNCTION", + "FUNCTION test(x: INTEGER)\n RETURN x\nEND FUNCTION", + ] + + for code in variations: + result = parser.parse(code) + assert result is not None + + +class TestRecursionDetection: + """Tests for recursion detection.""" + + def test_simple_recursion(self, parser): + """Test simple recursion detection.""" + code = """FUNCTION factorial(n) + IF n <= 1 THEN + RETURN 1 + END IF + RETURN n * factorial(n-1) +END FUNCTION""" + + structure = parser.detect_structure(code) + assert structure['has_recursion'] + + def test_binary_recursion(self, parser): + """Test binary recursion detection (Fibonacci).""" + code = """FUNCTION fib(n) + IF n <= 1 THEN + RETURN n + END IF + RETURN fib(n-1) + fib(n-2) +END FUNCTION""" + + structure = parser.detect_structure(code) + assert structure['has_recursion'] + + def test_divide_and_conquer_recursion(self, parser): + """Test divide-and-conquer recursion detection.""" + code = """FUNCTION mergeSort(A, left, right) + IF left < right THEN + mid = (left + right) / 2 + mergeSort(A, left, mid) + mergeSort(A, mid+1, right) + merge(A, left, mid, right) + END IF +END FUNCTION""" + + structure = parser.detect_structure(code) + assert structure['has_recursion'] + + def 
test_no_recursion(self, parser): + """Test that non-recursive code is not flagged.""" + code = """FUNCTION sum(A, n) + total = 0 + FOR i = 1 TO n DO + total = total + A[i] + END FOR + RETURN total +END FUNCTION""" + + structure = parser.detect_structure(code) + assert not structure['has_recursion'] + + +class TestErrorHandling: + """Tests for error handling scenarios.""" + + def test_malformed_loop(self, parser): + """Test handling of malformed loop.""" + code = """FOR i = 1 TO + x = x + 1 +END FOR""" + + result = parser.parse(code) + # Should not crash, may have errors or use fallback + assert result is not None + + def test_unclosed_block(self, parser): + """Test handling of unclosed block.""" + code = """IF x > 0 THEN + y = 1 + # Missing END IF""" + + result = parser.parse(code) + assert result is not None + + def test_mismatched_keywords(self, parser): + """Test handling of mismatched keywords.""" + code = """FOR i = 1 TO n DO + x = x + 1 +END WHILE""" # Mismatched: FOR with END WHILE + + result = parser.parse(code) + assert result is not None + + def test_unknown_constructs(self, parser): + """Test handling of unknown constructs.""" + code = """MYSTERY_KEYWORD x = 1 +ANOTHER_WEIRD_THING y = 2""" + + result = parser.parse(code) + assert result is not None + + def test_empty_blocks(self, parser): + """Test handling of empty blocks.""" + code = """IF x > 0 THEN +END IF + +FOR i = 1 TO n DO +END FOR""" + + result = parser.parse(code) + assert result is not None + + +class TestPreprocessorIntegration: + """Tests for preprocessor integration.""" + + def test_typo_correction_in_pipeline(self, parser): + """Test that typos are corrected during parsing.""" + code = """WHLIE i < n DO + i = i + 1 +END WHLIE""" + + result = parser.parse(code) + + # Preprocessor should fix "WHLIE" to "WHILE" + assert "while" in result.normalized_code.lower() or len(result.warnings) > 0 + + def test_operator_normalization_in_pipeline(self, parser): + """Test that operators are normalized 
during parsing.""" + code = """x ← 1 +IF a ≤ b THEN + y ← 2 +END IF""" + + result = parser.parse(code) + + # Operators should be normalized + assert "←" not in result.normalized_code + + def test_case_normalization_in_pipeline(self, parser): + """Test that keywords are case-normalized.""" + code = """FOR i = 1 To n Do + PRINT(i) +End FOR""" + + result = parser.parse(code) + + # Should detect loop regardless of case + structure = parser.detect_structure(code) + assert structure['has_loops'] + + +class TestComplexAlgorithms: + """Tests for complex algorithm parsing.""" + + def test_quicksort(self, parser): + """Test parsing quicksort algorithm.""" + code = """FUNCTION quickSort(A, low, high) + IF low < high THEN + pivot = partition(A, low, high) + quickSort(A, low, pivot - 1) + quickSort(A, pivot + 1, high) + END IF +END FUNCTION + +FUNCTION partition(A, low, high) + pivot = A[high] + i = low - 1 + FOR j = low TO high - 1 DO + IF A[j] <= pivot THEN + i = i + 1 + swap(A[i], A[j]) + END IF + END FOR + swap(A[i + 1], A[high]) + RETURN i + 1 +END FUNCTION""" + + result = parser.parse(code) + structure = parser.detect_structure(code) + + assert structure['has_loops'] + assert structure['has_recursion'] + assert structure['has_conditionals'] + + def test_dijkstra(self, parser): + """Test parsing Dijkstra's algorithm.""" + code = """FUNCTION dijkstra(G, source) + dist[source] = 0 + FOR EACH vertex v IN G DO + IF v != source THEN + dist[v] = INFINITY + END IF + add v to Q + END FOR + + WHILE Q is not empty DO + u = vertex in Q with min dist[u] + remove u from Q + + FOR EACH neighbor v of u DO + alt = dist[u] + length(u, v) + IF alt < dist[v] THEN + dist[v] = alt + END IF + END FOR + END WHILE + + RETURN dist +END FUNCTION""" + + result = parser.parse(code) + structure = parser.detect_structure(code) + + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['has_conditionals'] + + def test_dfs(self, parser): + """Test parsing DFS 
algorithm.""" + code = """FUNCTION DFS(G, v, visited) + visited[v] = true + print(v) + + FOR EACH neighbor u of v DO + IF NOT visited[u] THEN + DFS(G, u, visited) + END IF + END FOR +END FUNCTION""" + + result = parser.parse(code) + structure = parser.detect_structure(code) + + assert structure['has_loops'] + assert structure['has_recursion'] + assert structure['has_conditionals'] + + +class TestEdgeCases: + """Tests for edge cases.""" + + def test_deeply_nested_structure(self, parser, deeply_nested): + """Test deeply nested loops.""" + result = parser.parse(deeply_nested) + structure = parser.detect_structure(deeply_nested) + + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 5 + + def test_very_long_code(self, parser): + """Test parsing very long code.""" + # Generate 200 lines of code + lines = [] + lines.append("FUNCTION longFunction(n)") + lines.append(" x = 0") + for i in range(100): + lines.append(f" x = x + {i}") + lines.append(" FOR i = 1 TO n DO") + lines.append(" y = y + 1") + lines.append(" END FOR") + lines.append(" RETURN x") + lines.append("END FUNCTION") + + code = "\n".join(lines) + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_single_line_constructs(self, parser): + """Test single-line constructs.""" + code = "IF x > 0 THEN y = 1" + + result = parser.parse(code) + structure = parser.detect_structure(code) + assert structure['has_conditionals'] + + def test_multiple_functions(self, parser, multiple_functions): + """Test multiple function definitions.""" + result = parser.parse(multiple_functions) + + assert result is not None + + def test_unicode_identifiers(self, parser): + """Test code with unicode identifiers.""" + code = """σ = 0 +FOR i = 1 TO n DO + σ = σ + A[i] +END FOR""" + + result = parser.parse(code) + # Should not crash + assert result is not None + + def test_mixed_loops_and_recursion(self, parser): + """Test 
code with both loops and recursion.""" + code = """FUNCTION process(A, n) + FOR i = 1 TO n DO + IF A[i] > 0 THEN + result = process(A, A[i]) + END IF + END FOR + RETURN result +END FUNCTION""" + + result = parser.parse(code) + structure = parser.detect_structure(code) + + assert structure['has_loops'] + assert structure['has_recursion'] + assert structure['has_conditionals'] + + +class TestComplexityTestCases: + """Tests using predefined complexity test cases.""" + + def test_complexity_cases(self, parser, complexity_test_cases): + """Test parsing of all complexity test cases.""" + for test_case in complexity_test_cases: + code = test_case["code"] + expected_time = test_case["expected_time"] + + result = parser.parse(code) + + # Should parse successfully + assert result is not None, f"Failed to parse {test_case['name']}" + + # Structure detection should work + structure = parser.detect_structure(code) + + # Basic sanity checks based on expected complexity + if expected_time == ComplexityClass.CONSTANT: + assert not structure['has_loops'], f"{test_case['name']} shouldn't have loops" + elif expected_time in [ComplexityClass.LINEAR, ComplexityClass.QUADRATIC, ComplexityClass.CUBIC]: + assert structure['has_loops'], f"{test_case['name']} should have loops" diff --git a/evaluation_function/tests/test_parser.py b/evaluation_function/tests/test_parser.py new file mode 100644 index 0000000..09aa96c --- /dev/null +++ b/evaluation_function/tests/test_parser.py @@ -0,0 +1,726 @@ +""" +Comprehensive tests for the Parser module. 
+ +Tests cover: +- Basic parsing functionality +- Loop parsing (for, while, repeat, foreach) +- Conditional parsing (if/else/elif) +- Function parsing +- Expression parsing +- Error handling and fallback +- Structure detection +""" + +import pytest +from ..parser.parser import PseudocodeParser, ParseError, ParserConfig +from ..schemas.ast_nodes import ( + ProgramNode, FunctionNode, BlockNode, LoopNode, ConditionalNode, + AssignmentNode, ReturnNode, FunctionCallNode, VariableNode, + LiteralNode, BinaryOpNode, LoopType, NodeType +) + + +class TestBasicParsing: + """Tests for basic parsing functionality.""" + + def test_parse_returns_parse_result(self, parser): + """Test that parse returns a ParseResult object.""" + result = parser.parse("x = 1") + + assert hasattr(result, 'success') + assert hasattr(result, 'ast') + assert hasattr(result, 'errors') + assert hasattr(result, 'warnings') + + def test_parse_simple_assignment(self, parser): + """Test parsing simple assignment.""" + result = parser.parse("x = 1") + + assert result.success or len(result.errors) > 0 # May use fallback + + def test_parse_empty_input(self, parser): + """Test parsing empty input.""" + result = parser.parse("") + + # Should handle gracefully + assert result is not None + + def test_parse_whitespace_only(self, parser): + """Test parsing whitespace-only input.""" + result = parser.parse(" \n\n ") + + assert result is not None + + def test_normalized_code_returned(self, parser): + """Test that normalized code is included in result.""" + result = parser.parse("FOR i = 1 TO n DO\n print(i)\nEND FOR") + + assert result.normalized_code is not None + + +class TestForLoopParsing: + """Tests for FOR loop parsing.""" + + def test_parse_simple_for_loop(self, parser, simple_for_loop): + """Test parsing simple FOR loop.""" + result = parser.parse(simple_for_loop) + + # Check structure was detected + structure = parser.detect_structure(simple_for_loop) + assert structure['has_loops'] + assert 
structure['loop_count'] >= 1 + + def test_parse_for_loop_with_range(self, parser): + """Test parsing FOR loop with numeric range.""" + code = """FOR i = 1 TO 10 DO + print(i) +END FOR""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_parse_for_loop_with_step(self, parser): + """Test parsing FOR loop with step.""" + code = """FOR i = 1 TO n STEP 2 DO + print(i) +END FOR""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_parse_for_loop_downto(self, parser): + """Test parsing FOR loop with DOWNTO.""" + code = """FOR i = n DOWNTO 1 DO + print(i) +END FOR""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_parse_nested_for_loops(self, parser, nested_for_loops): + """Test parsing nested FOR loops.""" + result = parser.parse(nested_for_loops) + + structure = parser.detect_structure(nested_for_loops) + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 2 + + def test_parse_triple_nested_loops(self, parser, triple_nested_loops): + """Test parsing triple nested loops.""" + result = parser.parse(triple_nested_loops) + + structure = parser.detect_structure(triple_nested_loops) + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 3 + + +class TestWhileLoopParsing: + """Tests for WHILE loop parsing.""" + + def test_parse_simple_while_loop(self, parser, while_loop): + """Test parsing simple WHILE loop.""" + result = parser.parse(while_loop) + + structure = parser.detect_structure(while_loop) + assert structure['has_loops'] + + def test_parse_while_with_complex_condition(self, parser): + """Test parsing WHILE with complex condition.""" + code = """WHILE i < n AND j > 0 DO + i = i + 1 + j = j - 1 +END WHILE""" + result = parser.parse(code) + + structure = 
parser.detect_structure(code) + assert structure['has_loops'] + + def test_parse_nested_while_loops(self, parser): + """Test parsing nested WHILE loops.""" + code = """WHILE i < n DO + WHILE j < m DO + x = x + 1 + j = j + 1 + END WHILE + i = i + 1 +END WHILE""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + assert structure['loop_count'] >= 2 + + +class TestRepeatUntilParsing: + """Tests for REPEAT-UNTIL loop parsing.""" + + def test_parse_repeat_until(self, parser, repeat_until_loop): + """Test parsing REPEAT-UNTIL loop.""" + result = parser.parse(repeat_until_loop) + + structure = parser.detect_structure(repeat_until_loop) + assert structure['has_loops'] + + def test_parse_repeat_with_complex_body(self, parser): + """Test parsing REPEAT with complex body.""" + code = """REPEAT + x = x + 1 + y = y * 2 + IF x > 10 THEN + z = z + 1 + END IF +UNTIL x >= n""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + assert structure['has_conditionals'] + + +class TestForEachParsing: + """Tests for FOR-EACH loop parsing.""" + + def test_parse_foreach_loop(self, parser, foreach_loop): + """Test parsing FOR-EACH loop.""" + result = parser.parse(foreach_loop) + + structure = parser.detect_structure(foreach_loop) + assert structure['has_loops'] + + def test_parse_foreach_variations(self, parser): + """Test parsing FOR-EACH variations.""" + variations = [ + "FOR EACH item IN list DO\n print(item)\nEND FOR", + "FOR item IN list DO\n print(item)\nEND FOR", + "for each x in array do\n process(x)\nend for", + ] + + for code in variations: + result = parser.parse(code) + structure = parser.detect_structure(code) + assert structure['has_loops'] + + +class TestConditionalParsing: + """Tests for conditional (IF/ELSE) parsing.""" + + def test_parse_simple_if(self, parser): + """Test parsing simple IF statement.""" + code = """IF x > 0 THEN + print(x) +END IF""" + result 
= parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_conditionals'] + + def test_parse_if_else(self, parser): + """Test parsing IF-ELSE statement.""" + code = """IF x > 0 THEN + print("positive") +ELSE + print("non-positive") +END IF""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_conditionals'] + + def test_parse_if_elif_else(self, parser): + """Test parsing IF-ELIF-ELSE statement.""" + code = """IF x > 0 THEN + print("positive") +ELIF x < 0 THEN + print("negative") +ELSE + print("zero") +END IF""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_conditionals'] + + def test_parse_nested_conditionals(self, parser): + """Test parsing nested conditionals.""" + code = """IF x > 0 THEN + IF y > 0 THEN + print("both positive") + ELSE + print("x positive, y non-positive") + END IF +ELSE + print("x non-positive") +END IF""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_conditionals'] + + +class TestFunctionParsing: + """Tests for function definition parsing.""" + + def test_parse_simple_function(self, parser): + """Test parsing simple function definition.""" + code = """FUNCTION test() + RETURN 1 +END FUNCTION""" + result = parser.parse(code) + + if result.success and result.ast: + assert len(result.ast.functions) > 0 or result.ast.global_statements is not None + + def test_parse_function_with_parameters(self, parser): + """Test parsing function with parameters.""" + code = """FUNCTION add(a, b) + RETURN a + b +END FUNCTION""" + result = parser.parse(code) + + # Should parse without errors or use fallback + assert result is not None + + def test_parse_function_with_array_parameter(self, parser): + """Test parsing function with array parameter.""" + code = """FUNCTION sum(A[1..n]) + total = 0 + FOR i = 1 TO n DO + total = total + A[i] + END FOR + RETURN total +END FUNCTION""" + 
result = parser.parse(code) + + assert result is not None + + def test_parse_multiple_functions(self, parser, multiple_functions): + """Test parsing multiple function definitions.""" + result = parser.parse(multiple_functions) + + # Should recognize multiple functions + assert result is not None + + def test_parse_recursive_function(self, parser, recursive_fibonacci): + """Test parsing recursive function.""" + result = parser.parse(recursive_fibonacci) + + structure = parser.detect_structure(recursive_fibonacci) + assert structure['has_recursion'] + + +class TestRecursionDetection: + """Tests for recursion detection.""" + + def test_detect_simple_recursion(self, parser, recursive_factorial): + """Test detection of simple recursion.""" + structure = parser.detect_structure(recursive_factorial) + assert structure['has_recursion'] + + def test_detect_double_recursion(self, parser, recursive_fibonacci): + """Test detection of double recursion (like Fibonacci).""" + structure = parser.detect_structure(recursive_fibonacci) + assert structure['has_recursion'] + + def test_detect_divide_conquer_recursion(self, parser, merge_sort): + """Test detection of divide-and-conquer recursion.""" + structure = parser.detect_structure(merge_sort) + assert structure['has_recursion'] + + def test_no_false_recursion_detection(self, parser, linear_search): + """Test that non-recursive code is not flagged as recursive.""" + structure = parser.detect_structure(linear_search) + assert not structure['has_recursion'] + + +class TestExpressionParsing: + """Tests for expression parsing.""" + + def test_parse_arithmetic_expressions(self, parser): + """Test parsing arithmetic expressions.""" + code = "x = a + b * c - d / e" + result = parser.parse(code) + + assert result is not None + + def test_parse_comparison_expressions(self, parser): + """Test parsing comparison expressions.""" + expressions = [ + "IF a == b THEN x = 1 END IF", + "IF a != b THEN x = 1 END IF", + "IF a < b THEN x = 1 END IF", 
+ "IF a <= b THEN x = 1 END IF", + "IF a > b THEN x = 1 END IF", + "IF a >= b THEN x = 1 END IF", + ] + + for code in expressions: + result = parser.parse(code) + assert result is not None + + def test_parse_logical_expressions(self, parser): + """Test parsing logical expressions.""" + code = "IF a AND b OR NOT c THEN x = 1 END IF" + result = parser.parse(code) + + assert result is not None + + def test_parse_array_access(self, parser): + """Test parsing array access expressions.""" + code = """x = A[i] +y = B[i][j] +z = C[i + 1][j - 1]""" + result = parser.parse(code) + + assert result is not None + + def test_parse_function_call_expression(self, parser): + """Test parsing function call expressions.""" + code = """x = max(a, b) +y = min(c, d) +z = sqrt(x * x + y * y)""" + result = parser.parse(code) + + assert result is not None + + +class TestStructureDetection: + """Tests for structure detection.""" + + def test_detect_no_loops(self, parser): + """Test structure detection with no loops.""" + code = """x = 1 +y = 2 +z = x + y""" + structure = parser.detect_structure(code) + + assert not structure['has_loops'] + assert structure['loop_count'] == 0 + + def test_detect_single_loop(self, parser, simple_for_loop): + """Test structure detection with single loop.""" + structure = parser.detect_structure(simple_for_loop) + + assert structure['has_loops'] + assert structure['loop_count'] >= 1 + assert not structure['has_nested_loops'] or structure['max_nesting'] == 1 + + def test_detect_nested_loops(self, parser, nested_for_loops): + """Test structure detection with nested loops.""" + structure = parser.detect_structure(nested_for_loops) + + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 2 + + def test_detect_conditionals(self, parser): + """Test structure detection with conditionals.""" + code = """IF x > 0 THEN + y = 1 +ELSE + y = -1 +END IF""" + structure = parser.detect_structure(code) + + assert 
structure['has_conditionals'] + + def test_detect_complex_structure(self, parser, bubble_sort): + """Test structure detection with complex algorithm.""" + structure = parser.detect_structure(bubble_sort) + + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['has_conditionals'] + + +class TestStyleVariations: + """Tests for different pseudocode style variations.""" + + def test_parse_python_style(self, parser, python_style_loop): + """Test parsing Python-style pseudocode.""" + result = parser.parse(python_style_loop) + + structure = parser.detect_structure(python_style_loop) + assert structure['has_loops'] + + def test_parse_pascal_style(self, parser, pascal_style_loop): + """Test parsing Pascal-style pseudocode.""" + result = parser.parse(pascal_style_loop) + + structure = parser.detect_structure(pascal_style_loop) + assert structure['has_loops'] + + def test_parse_mixed_case(self, parser, mixed_case_keywords): + """Test parsing mixed case keywords.""" + result = parser.parse(mixed_case_keywords) + + structure = parser.detect_structure(mixed_case_keywords) + assert structure['has_loops'] + + def test_parse_unicode_operators(self, parser, unicode_operators): + """Test parsing unicode operators.""" + result = parser.parse(unicode_operators) + + assert result is not None + + +class TestErrorHandling: + """Tests for error handling.""" + + def test_handle_syntax_error(self, parser): + """Test handling of syntax errors.""" + code = "FOR i = 1 TO" # Incomplete + result = parser.parse(code) + + # Should not crash, may have errors or use fallback + assert result is not None + + def test_handle_mismatched_blocks(self, parser): + """Test handling of mismatched block delimiters.""" + code = """FOR i = 1 TO n DO + IF x > 0 THEN + print(x) +END FOR""" # Missing END IF + result = parser.parse(code) + + assert result is not None + + def test_handle_unknown_keywords(self, parser): + """Test handling of unknown keywords.""" + code = 
"""UNKNOWN_KEYWORD x = 1 +ANOTHER_WEIRD_THING y = 2""" + result = parser.parse(code) + + assert result is not None + + def test_fallback_on_parse_error(self, parser): + """Test that fallback parser is used on error.""" + code = """FOR i = 1 TO n DO + malformed { syntax [ here + x = x + 1 +END FOR""" + result = parser.parse(code) + + # Should use fallback and still detect loop + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_strict_mode_no_fallback(self): + """Test that strict mode doesn't use fallback for malformed input.""" + config = ParserConfig(strict_mode=True) + parser = PseudocodeParser(config) + + code = "FOR i = 1 TO" # Incomplete + result = parser.parse(code) + + # In strict mode, may have errors or warnings, but should still return a result + # The parser is now more resilient + assert result is not None + + +class TestCompleteAlgorithms: + """Tests for parsing complete algorithms.""" + + def test_parse_binary_search(self, parser, binary_search): + """Test parsing binary search algorithm.""" + result = parser.parse(binary_search) + + structure = parser.detect_structure(binary_search) + assert structure['has_loops'] + assert structure['has_conditionals'] + + def test_parse_bubble_sort(self, parser, bubble_sort): + """Test parsing bubble sort algorithm.""" + result = parser.parse(bubble_sort) + + structure = parser.detect_structure(bubble_sort) + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['has_conditionals'] + + def test_parse_merge_sort(self, parser, merge_sort): + """Test parsing merge sort algorithm.""" + result = parser.parse(merge_sort) + + structure = parser.detect_structure(merge_sort) + assert structure['has_recursion'] + + def test_parse_matrix_multiplication(self, parser, matrix_multiplication): + """Test parsing matrix multiplication.""" + result = parser.parse(matrix_multiplication) + + structure = parser.detect_structure(matrix_multiplication) + assert 
structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 3 + + +class TestEdgeCases: + """Tests for edge cases.""" + + def test_deeply_nested_structure(self, parser, deeply_nested): + """Test parsing deeply nested structure.""" + result = parser.parse(deeply_nested) + + structure = parser.detect_structure(deeply_nested) + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 5 + + def test_empty_function_body(self, parser, empty_function): + """Test parsing empty function body.""" + result = parser.parse(empty_function) + + assert result is not None + + def test_single_statement(self, parser): + """Test parsing single statement.""" + result = parser.parse("x = 1") + + assert result is not None + + def test_comments_handling(self, parser): + """Test that comments are handled.""" + code = """// This is a comment +FOR i = 1 TO n DO + # Another comment + x = x + 1 +END FOR""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_very_long_code(self, parser): + """Test parsing very long code.""" + # Generate code with many statements + statements = [f"x{i} = {i}" for i in range(100)] + code = "\n".join(statements) + + result = parser.parse(code) + + assert result is not None + + +class TestCurlyBraceBlocks: + """Tests for curly brace block syntax.""" + + def test_for_loop_with_curly_braces(self, parser): + """Test parsing FOR loop with curly braces.""" + code = """FOR i = 1 TO n { + x = x + 1 +}""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + assert structure['loop_count'] >= 1 + + def test_while_loop_with_curly_braces(self, parser): + """Test parsing WHILE loop with curly braces.""" + code = """WHILE x > 0 { + x = x - 1 +}""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def 
test_if_statement_with_curly_braces(self, parser): + """Test parsing IF statement with curly braces.""" + code = """IF x > 0 { + y = 1 +}""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_conditionals'] + + def test_function_with_curly_braces(self, parser): + """Test parsing function with curly braces.""" + code = """FUNCTION test(n) { + FOR i = 1 TO n { + x = x + 1 + } +}""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_nested_loops_with_curly_braces(self, parser): + """Test parsing nested loops with curly braces.""" + code = """FOR i = 1 TO n { + FOR j = 1 TO n { + x = x + 1 + } +}""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + assert structure['has_nested_loops'] + assert structure['loop_count'] >= 2 + + def test_mixed_end_and_braces(self, parser): + """Test mixing END keywords and curly braces.""" + code = """FOR i = 1 TO n { + IF x > 0 THEN + y = 1 + END IF +}""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + assert structure['has_conditionals'] + + +class TestCallKeyword: + """Tests for CALL keyword function invocation.""" + + def test_call_keyword_statement(self, parser): + """Test CALL keyword for function invocation.""" + code = """FOR i = 1 TO n DO + CALL print(i) +END FOR""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_call_keyword_in_function(self, parser): + """Test CALL keyword within a function.""" + code = """FUNCTION test(A, n) + FOR i = 1 TO n DO + CALL process(A[i]) + END FOR +END FUNCTION""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_direct_function_call(self, parser): + """Test direct function call without CALL keyword.""" + code = """FOR i 
= 1 TO n DO + print(i) +END FOR""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] + + def test_call_with_curly_braces(self, parser): + """Test CALL keyword with curly brace blocks.""" + code = """FOR i = 1 TO n { + CALL process(i) +}""" + result = parser.parse(code) + + structure = parser.detect_structure(code) + assert structure['has_loops'] diff --git a/evaluation_function/tests/test_preprocessor.py b/evaluation_function/tests/test_preprocessor.py new file mode 100644 index 0000000..7fb4f79 --- /dev/null +++ b/evaluation_function/tests/test_preprocessor.py @@ -0,0 +1,440 @@ +""" +Comprehensive tests for the Preprocessor module. + +Tests cover: +- Keyword normalization (case variations) +- Operator normalization (assignment, comparison) +- Whitespace normalization +- Typo correction +- String preservation +- Edge cases +""" + +import pytest +from ..parser.preprocessor import Preprocessor, PreprocessorConfig + + +class TestKeywordNormalization: + """Tests for keyword case normalization.""" + + def test_for_keyword_variations(self, preprocessor): + """Test FOR keyword case variations.""" + variations = ["FOR", "For", "for", "FOR", "fOr"] + + for var in variations: + code = f"{var} i = 1 to n do" + result, _ = preprocessor.preprocess(code) + assert result.startswith("for"), f"'{var}' should normalize to 'for'" + + def test_while_keyword_variations(self, preprocessor): + """Test WHILE keyword case variations.""" + variations = ["WHILE", "While", "while", "WHILST", "whilst"] + + for var in variations: + code = f"{var} i < n do" + result, _ = preprocessor.preprocess(code) + assert result.startswith("while"), f"'{var}' should normalize to 'while'" + + def test_if_then_else_variations(self, preprocessor): + """Test IF/THEN/ELSE keyword variations.""" + code = "IF condition THEN do_something ELSE do_other" + result, _ = preprocessor.preprocess(code) + + assert "if" in result + assert "then" in result + assert 
"else" in result + + def test_function_keyword_variations(self, preprocessor): + """Test FUNCTION keyword variations.""" + variations = [ + "FUNCTION test()", + "Function test()", + "function test()", + "ALGORITHM test()", + "Algorithm test()", + "PROCEDURE test()", + "DEF test()", + ] + + for code in variations: + result, _ = preprocessor.preprocess(code) + assert "function" in result or "algorithm" in result or "procedure" in result or "def" in result + + def test_return_keyword_variations(self, preprocessor): + """Test RETURN keyword variations.""" + variations = ["RETURN x", "Return x", "return x", "RETURNS x"] + + for code in variations: + result, _ = preprocessor.preprocess(code) + assert "return" in result + + def test_boolean_literal_normalization(self, preprocessor): + """Test TRUE/FALSE normalization.""" + code = "x = TRUE\ny = FALSE" + result, _ = preprocessor.preprocess(code) + + assert "true" in result + assert "false" in result + + def test_logical_operator_variations(self, preprocessor): + """Test AND/OR/NOT variations.""" + test_cases = [ + ("a AND b", "and"), + ("a And b", "and"), + ("a && b", "and"), # && gets replaced with " and " + ("a OR b", "or"), + ("a Or b", "or"), + ("a || b", "or"), # || gets replaced with " or " + ("NOT a", "not"), + ("Not a", "not"), + ] + + for code, expected in test_cases: + result, _ = preprocessor.preprocess(code) + # Check that the expected word appears (may have extra spaces) + assert expected in result.lower(), f"'{code}' should contain '{expected}', got '{result}'" + + def test_end_keyword_variations(self, preprocessor): + """Test END keyword variations.""" + variations = [ + "END FOR", "ENDFOR", "endfor", "End For", + "END WHILE", "ENDWHILE", "endwhile", + "END IF", "ENDIF", "endif", + "DONE", "done", + ] + + for code in variations: + result, _ = preprocessor.preprocess(code) + assert "end" in result or "done" in result + + +class TestOperatorNormalization: + """Tests for operator normalization.""" + + def 
test_assignment_operators(self, preprocessor): + """Test assignment operator normalization.""" + test_cases = [ + ("x := 5", "x = 5"), + ("x ← 5", "x = 5"), + ("x <- 5", "x = 5"), + ] + + for code, expected in test_cases: + result, _ = preprocessor.preprocess(code) + assert "=" in result + assert ":=" not in result + assert "←" not in result + assert "<-" not in result + + def test_unicode_comparison_operators(self, preprocessor): + """Test unicode comparison operator normalization.""" + test_cases = [ + ("a ≤ b", "<="), + ("a ≥ b", ">="), + ("a ≠ b", "!="), + ] + + for code, expected in test_cases: + result, _ = preprocessor.preprocess(code) + assert expected in result + + def test_not_equal_variations(self, preprocessor): + """Test not-equal operator variations.""" + test_cases = [ + ("a <> b", "!="), + ("a ≠ b", "!="), + ] + + for code, expected in test_cases: + result, _ = preprocessor.preprocess(code) + assert expected in result + + +class TestTypoCorrection: + """Tests for typo correction.""" + + def test_common_keyword_typos(self, preprocessor): + """Test correction of common keyword typos.""" + typos = { + "whlie": "while", + "wihle": "while", + "fro": "for", + "eles": "else", + "esle": "else", + "retrun": "return", + "reutrn": "return", + "fucntion": "function", + "funtion": "function", + "algoritm": "algorithm", + "pritn": "print", + } + + for typo, correct in typos.items(): + code = f"{typo} test" + result, warnings = preprocessor.preprocess(code) + assert correct in result.lower(), f"'{typo}' should be corrected to '{correct}', got '{result}'" + assert len(warnings) > 0, f"Warning should be generated for typo '{typo}'" + + def test_typo_warning_message(self, preprocessor): + """Test that typo corrections generate appropriate warnings.""" + code = "whlie condition do" + result, warnings = preprocessor.preprocess(code) + + assert len(warnings) > 0 + assert any("whlie" in w.lower() or "fixed" in w.lower() for w in warnings) + + def 
class TestWhitespaceNormalization:
    """Tests for whitespace normalization.

    Inputs whose meaning depends on an exact number of spaces are written
    with explicit repetition (``" " * 4``) so the intent survives
    reformatting of this file.
    """

    def test_tab_to_space_conversion(self, preprocessor):
        """Test tab to space conversion."""
        code = "FOR i = 1 TO n DO\n\t\tprint(i)"
        result, _ = preprocessor.preprocess(code)

        assert "\t" not in result

    def test_trailing_whitespace_removal(self, preprocessor):
        """Test trailing whitespace removal."""
        code = "x = 1 \ny = 2 "
        result, _ = preprocessor.preprocess(code)

        for line in result.split('\n'):
            assert line == line.rstrip()

    def test_multiple_blank_lines_collapse(self, preprocessor):
        """Test that runs of blank lines collapse.

        The check is that no three consecutive newlines remain in the output.
        """
        code = "x = 1\n\n\n\n\ny = 2"
        result, _ = preprocessor.preprocess(code)

        assert "\n\n\n" not in result

    def test_multiple_spaces_normalization(self, preprocessor):
        """Test multiple spaces normalized to single space."""
        gap = " " * 4
        code = f"x{gap}={gap}1"
        result, _ = preprocessor.preprocess(code)

        # Leading whitespace preserved, but multiple internal spaces collapsed.
        # A doubled space is the smallest run that proves nothing collapsed.
        assert " " * 2 not in result.strip()

    def test_indentation_preserved(self, preprocessor):
        """Test that meaningful indentation is preserved."""
        indent = " " * 4
        code = "\n".join([
            "FOR i = 1 TO n DO",
            f"{indent}FOR j = 1 TO n DO",
            f"{indent * 2}x = x + 1",
            f"{indent}END FOR",
            "END FOR",
        ])
        result, _ = preprocessor.preprocess(code)

        lines = result.split('\n')
        # Check that indentation structure is preserved: the inner FOR must
        # still be indented (a 4-space prefix also satisfies this check).
        assert lines[1].startswith(' ')
class TestIndentationDetection:
    """Tests for indentation style detection.

    The inputs differ only in how many spaces indent the loop body, so the
    indentation is spelled out with ``" " * n`` rather than embedded in a
    literal where it is easy to lose.
    """

    @staticmethod
    def _loop_with_indent(width):
        """Return a two-statement loop whose body is indented by *width* spaces."""
        pad = " " * width
        return "\n".join([
            "for i = 1 to n do",
            f"{pad}x = x + 1",
            f"{pad}y = y + 1",
            "end for",
        ])

    def test_detect_2_space_indent(self, preprocessor):
        """Test detection of 2-space indentation."""
        code = self._loop_with_indent(2)

        indent_unit = preprocessor.detect_indentation_style(code)
        assert indent_unit == 2

    def test_detect_4_space_indent(self, preprocessor):
        """Test detection of 4-space indentation."""
        code = self._loop_with_indent(4)

        indent_unit = preprocessor.detect_indentation_style(code)
        assert indent_unit == 4

    def test_get_indent_level(self, preprocessor):
        """Test getting indent level of lines.

        With an indent unit of 4, zero, four and eight leading spaces are
        expected to map to levels 0, 1 and 2.
        """
        assert preprocessor.get_indent_level("x = 1", 4) == 0
        assert preprocessor.get_indent_level(" " * 4 + "x = 1", 4) == 1
        assert preprocessor.get_indent_level(" " * 8 + "x = 1", 4) == 2
"""Test handling of whitespace-only input.""" + result, warnings = preprocessor.preprocess(" \n\n \t\t") + assert result.strip() == "" + + def test_single_line(self, preprocessor): + """Test handling of single line input.""" + code = "x = 1" + result, _ = preprocessor.preprocess(code) + assert result == "x = 1" + + def test_windows_line_endings(self, preprocessor): + """Test handling of Windows line endings.""" + code = "FOR i = 1 TO n DO\r\n print(i)\r\nEND FOR" + result, _ = preprocessor.preprocess(code) + + assert "\r\n" not in result + assert "\r" not in result + + def test_mac_line_endings(self, preprocessor): + """Test handling of old Mac line endings.""" + code = "FOR i = 1 TO n DO\r print(i)\rEND FOR" + result, _ = preprocessor.preprocess(code) + + assert "\r" not in result + + def test_unicode_identifiers(self, preprocessor): + """Test handling of unicode in identifiers.""" + code = "variablé = 1\nπ = 3.14" + result, _ = preprocessor.preprocess(code) + + # Should not crash on unicode + assert "=" in result + + def test_very_long_lines(self, preprocessor): + """Test handling of very long lines.""" + long_expr = "x = " + " + ".join(["a"] * 100) + result, _ = preprocessor.preprocess(long_expr) + + assert "x =" in result + + def test_nested_strings(self, preprocessor): + """Test handling of nested quotes (escaped).""" + code = r'print("He said \"hello\"")' + result, _ = preprocessor.preprocess(code) + + # Should handle escaped quotes + assert "print" in result + + +class TestPreprocessorConfig: + """Tests for preprocessor configuration options.""" + + def test_disable_case_normalization(self): + """Test disabling case normalization.""" + config = PreprocessorConfig(normalize_case=False) + preprocessor = Preprocessor(config) + + code = "FOR i = 1 TO n DO" + result, _ = preprocessor.preprocess(code) + + # Keywords should retain original case + assert "FOR" in result or "for" in result # May still be affected by other normalizations + + def 
test_disable_operator_normalization(self): + """Test disabling operator normalization.""" + config = PreprocessorConfig(normalize_operators=False) + preprocessor = Preprocessor(config) + + code = "x := 5" + result, _ = preprocessor.preprocess(code) + + # Should keep := operator + assert ":=" in result + + def test_disable_whitespace_normalization(self): + """Test disabling whitespace normalization.""" + config = PreprocessorConfig(normalize_whitespace=False) + preprocessor = Preprocessor(config) + + code = "x = 1" + result, _ = preprocessor.preprocess(code) + + # Multiple spaces should be preserved + assert " " in result + + def test_custom_tab_size(self): + """Test custom tab size.""" + config = PreprocessorConfig(tab_size=2) + preprocessor = Preprocessor(config) + + code = "FOR i = 1 TO n DO\n\tprint(i)" + result, _ = preprocessor.preprocess(code) + + # Tab should be converted to 2 spaces + lines = result.split('\n') + if len(lines) > 1: + assert lines[1].startswith(" ") or not lines[1].startswith(" ") + + +class TestComplexPseudocode: + """Tests for complex pseudocode examples.""" + + def test_full_algorithm_normalization(self, preprocessor, bubble_sort): + """Test normalization of complete bubble sort algorithm.""" + result, warnings = preprocessor.preprocess(bubble_sort) + + # Should have normalized keywords + assert "function" in result + assert "for" in result + assert "if" in result + + def test_mixed_style_normalization(self, preprocessor, mixed_case_keywords): + """Test normalization of mixed case keywords.""" + result, _ = preprocessor.preprocess(mixed_case_keywords) + + # All keywords should be lowercase + assert "FOR" not in result or "for" in result + assert "While" not in result or "while" in result + + def test_unicode_operators_normalization(self, preprocessor, unicode_operators): + """Test normalization of unicode operators.""" + result, _ = preprocessor.preprocess(unicode_operators) + + # Unicode operators should be converted + assert "←" not in 
class TestPreviewBasic:
    """Smoke tests: preview_function on small, well-formed programs."""

    def test_preview_simple_loop(self, params):
        """A single FOR loop parses and yields a time-complexity report."""
        from ..preview import preview_function

        source = "\n".join((
            "FOR i = 1 TO n DO",
            "    x = x + 1",
            "END FOR",
        ))
        outcome = preview_function(source, params)

        assert "preview" in outcome
        body = outcome["preview"]
        assert body is not None
        assert "feedback" in body
        text = body["feedback"]
        assert "Parsing: Successful" in text
        assert "Time Complexity" in text

    def test_preview_nested_loops(self, params):
        """Two nested FOR loops are reported as nested, with O(n²)."""
        from ..preview import preview_function

        source = "\n".join((
            "FOR i = 1 TO n DO",
            "    FOR j = 1 TO n DO",
            "        sum = sum + A[i][j]",
            "    END FOR",
            "END FOR",
        ))
        outcome = preview_function(source, params)

        text = outcome["preview"]["feedback"]
        assert "Nested loops" in text
        assert "O(n²)" in text

    def test_preview_recursion(self, params):
        """A self-calling factorial is reported as recursion."""
        from ..preview import preview_function

        source = "\n".join((
            "FUNCTION factorial(n)",
            "    IF n <= 1 THEN",
            "        RETURN 1",
            "    END IF",
            "    RETURN n * factorial(n - 1)",
            "END FUNCTION",
        ))
        outcome = preview_function(source, params)

        text = outcome["preview"]["feedback"]
        assert "Recursion" in text
class TestPreviewDictInput:
    """preview_function given a dict response instead of a bare string."""

    def test_preview_dict_response(self, params):
        """The pseudocode may be supplied under the 'pseudocode' key."""
        from ..preview import preview_function

        payload = dict(pseudocode="FOR i = 1 TO n DO\n x = x + 1\nEND FOR")
        outcome = preview_function(payload, params)

        feedback = outcome["preview"]["feedback"]
        assert "Parsing: Successful" in feedback

    def test_preview_with_code_key(self, params):
        """The pseudocode may be supplied under the 'code' key."""
        from ..preview import preview_function

        payload = dict(code="FOR i = 1 TO n DO\n x = x + 1\nEND FOR")
        outcome = preview_function(payload, params)

        feedback = outcome["preview"]["feedback"]
        assert "Parsing: Successful" in feedback
"Parsing: Successful" in preview["feedback"] + assert "Loops" in preview["feedback"] + + def test_preview_while_loop(self, params): + """Test preview detects while loop.""" + from ..preview import preview_function + + code = """WHILE x > 0 DO + x = x - 1 +END WHILE""" + result = preview_function(code, params) + + preview = result["preview"] + assert "Loops" in preview["feedback"] + + +class TestPreviewComplexityDetection: + """Test complexity detection in preview.""" + + def test_preview_binary_search(self, params): + """Test preview analyzes binary search correctly.""" + from ..preview import preview_function + + code = """FUNCTION binarySearch(A, target, low, high) + IF low > high THEN + RETURN -1 + END IF + mid = (low + high) / 2 + IF A[mid] == target THEN + RETURN mid + ELSE IF A[mid] < target THEN + RETURN binarySearch(A, target, mid + 1, high) + ELSE + RETURN binarySearch(A, target, low, mid - 1) + END IF +END FUNCTION""" + result = preview_function(code, params) + + preview = result["preview"] + assert "Recursion" in preview["feedback"] + assert "O(log n)" in preview["feedback"] + + def test_preview_constant_complexity(self, params): + """Test preview detects constant complexity.""" + from ..preview import preview_function + + code = """x = 1 +y = 2 +z = x + y""" + result = preview_function(code, params) + + preview = result["preview"] + assert "O(1)" in preview["feedback"] + + def test_preview_merge_sort_pattern(self, params): + """Test preview detects merge sort pattern.""" + from ..preview import preview_function + + code = """FUNCTION mergeSort(A, low, high) + IF low < high THEN + mid = (low + high) / 2 + mergeSort(A, low, mid) + mergeSort(A, mid + 1, high) + merge(A, low, mid, high) + END IF +END FUNCTION""" + result = preview_function(code, params) + + preview = result["preview"] + assert "O(n log n)" in preview["feedback"] + + +class TestPreviewLatex: + """Test LaTeX output in preview.""" + + def test_preview_latex_output(self, params): + """Test 
class TestPreviewEdgeCases:
    """preview_function on unusually long or deeply nested input."""

    def test_preview_very_long_code(self, params):
        """A loop body with fifty assignments still parses."""
        from ..preview import preview_function

        assignments = [f" x{j} = {j}" for j in range(50)]
        source = "\n".join(["FOR i = 1 TO n DO", *assignments, "END FOR"])

        outcome = preview_function(source, params)
        feedback = outcome["preview"]["feedback"]
        assert "Parsing: Successful" in feedback

    def test_preview_deeply_nested(self, params):
        """Three nested loops make the feedback mention depth."""
        from ..preview import preview_function

        source = "\n".join((
            "FOR i = 1 TO n DO",
            "    FOR j = 1 TO n DO",
            "        FOR k = 1 TO n DO",
            "            x = x + 1",
            "        END FOR",
            "    END FOR",
            "END FOR",
        ))
        outcome = preview_function(source, params)

        feedback = outcome["preview"]["feedback"]
        assert "depth" in feedback.lower()