diff --git a/Cargo.lock b/Cargo.lock index 92dd6de..714e7f0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1322,7 +1322,7 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "rust_pgn_reader_python_binding" -version = "3.3.0" +version = "3.4.0" dependencies = [ "arrow", "arrow-array", diff --git a/Cargo.toml b/Cargo.toml index 041a994..88f0eec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rust_pgn_reader_python_binding" -version = "3.3.0" +version = "3.4.0" edition = "2024" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/example_chess960.py b/src/example_chess960.py new file mode 100644 index 0000000..5105556 --- /dev/null +++ b/src/example_chess960.py @@ -0,0 +1,28 @@ +import rust_pgn_reader_python_binding + +pgn_moves = """ +[Event "lc0MG"] +[Site "internet"] +[Date "????.??.??"] +[Round "-"] +[White "lc0.net.714559"] +[Black "lc0.net.714558"] +[Result "1/2-1/2"] +[Variant "chess960"] +[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] + +1.g3 d5 2.d4 g6 3.b3 Nf6 4.Ne3 b6 5.Nh3 Ne6 6.f4 Ng7 7.g4 h6 8.Nf2 Ne6 9.f5 Nf4 10.Nf1 gxf5 11.Qd2 Ne6 12.gxf5 Ng5 13.Ng3 Qd7 14.Rg1 Rg8 15.O-O-O O-O-O 16.Kb1 Kb8 17.Bb2 h5 18.Qf4 Bb7 19.h4 Nge4 20.Nfxe4 dxe4 21.Nxe4 Nxe4 22.Bxe4 Bxe4 23.Qxe4 e6 24.Rxg8 Rxg8 25.fxe6 fxe6 26.Rf1 Qe7 27.e3 Bf6 28.Ba3 Qd8 29.Qxe6 Bxh4 30.Rf5 Bg5 31.d5 Qc8 32.e4 h4 33.Qf7 Be3 34.d6 Rg1+ 35.Rf1 Rxf1+ 36.Qxf1 h3 37.dxc7+ Kxc7 38.Qc4+ Kb7 39.Qf7+ Qc7 40.Qd5+ Qc6 41.Qf7+ Qc7 42.Qd5+ Qc6 43.Qf7+ Qc7 1/2-1/2 {OL: 0} +""" + +extractor = rust_pgn_reader_python_binding.parse_game(pgn_moves) + +print("moves", extractor.moves) +print("comments", extractor.comments) +print("valid", extractor.valid_moves) +# print(extractor.evals) +print("clock", extractor.clock_times) +# print(extractor.outcome) +# print(extractor.position_status.is_checkmate) +# print(extractor.position_status.is_stalemate) +# print(extractor.position_status.is_game_over) +# print(extractor.position_status.legal_move_count) diff --git a/src/lib.rs b/src/lib.rs index 05b572b..d38b337 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,9 @@ use pyo3::prelude::*; use pyo3_arrow::PyChunkedArray; use rayon::ThreadPoolBuilder; use rayon::prelude::*; +use shakmaty::CastlingMode; use shakmaty::Color; +use shakmaty::fen::Fen; use shakmaty::{Chess, Position, Role, Square, uci::UciMove}; use std::io::Cursor; use std::ops::ControlFlow; @@ -274,13 +276,54 @@ impl Visitor for MoveExtractor { self.moves.clear(); self.flat_legal_moves.clear(); self.legal_moves_offsets.clear(); - self.pos = Chess::default(); self.valid_moves = true; self.comments.clear(); self.evals.clear(); self.clock_times.clear(); self.castling_rights.clear(); + // Determine castling mode from Variant header (case-insensitive) + let castling_mode = self + .headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("Variant")) + .and_then(|(_, v)| { + let v_lower = v.to_lowercase(); + if v_lower == "chess960" { + Some(CastlingMode::Chess960) + } else { + None + } + }) + .unwrap_or(CastlingMode::Standard); + + // Try to parse FEN from headers, fall back to default position + let fen_header = self + .headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("FEN")) + .map(|(_, v)| v.as_str()); + + if let Some(fen_str) = fen_header { + match fen_str.parse::() { + Ok(fen) => match fen.into_position(castling_mode) { + Ok(pos) => self.pos = pos, + Err(e) => { + eprintln!("invalid FEN position: {}", e); + self.pos = Chess::default(); + self.valid_moves = false; + } + }, + Err(e) => { + eprintln!("failed to parse FEN: {}", e); + self.pos = Chess::default(); + self.valid_moves = false; + } + } + } else { + self.pos = Chess::default(); + } + self.push_castling_bitboards(); if self.store_legal_moves { self.push_legal_moves(); @@ -636,4 +679,98 @@ mod pyucimove_tests { assert_eq!(extractor.moves.len(), 7); assert_eq!(extractor.outcome, Some("Black".to_string())); } + + #[test] + fn test_parse_game_with_standard_fen() { + // A game starting from a mid-game position + let pgn = r#"[FEN "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] + +3. Bb5 a6 4. Ba4 Nf6 1-0"#; + let result = parse_single_game_native(pgn, false); + assert!(result.is_ok()); + let extractor = result.unwrap(); + assert!(extractor.valid_moves, "Moves should be valid"); + assert_eq!(extractor.moves.len(), 4); + } + + #[test] + fn test_parse_chess960_game() { + // Chess960 game with custom starting position + let pgn = r#"[Variant "chess960"] +[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] + +1. g3 d5 2. d4 g6 3. b3 Nf6 1-0"#; + let result = parse_single_game_native(pgn, false); + assert!(result.is_ok()); + let extractor = result.unwrap(); + assert!( + extractor.valid_moves, + "Chess960 moves should be valid with proper FEN" + ); + assert_eq!(extractor.moves.len(), 6); + } + + #[test] + fn test_parse_chess960_variant_case_insensitive() { + // Test that variant detection is case-insensitive + let pgn = r#"[Variant "Chess960"] +[FEN "brkrqnnb/pppppppp/8/8/8/8/PPPPPPPP/BRKRQNNB w KQkq - 0 1"] + +1. g3 d5 1-0"#; + let result = parse_single_game_native(pgn, false); + assert!(result.is_ok()); + let extractor = result.unwrap(); + assert!( + extractor.valid_moves, + "Should handle Chess960 case variations" + ); + } + + #[test] + fn test_parse_invalid_fen_falls_back() { + // Invalid FEN should fall back to default and mark invalid + let pgn = r#"[FEN "invalid fen string"] + +1. e4 e5 1-0"#; + let result = parse_single_game_native(pgn, false); + assert!(result.is_ok()); + let extractor = result.unwrap(); + assert!( + !extractor.valid_moves, + "Should mark as invalid when FEN parsing fails" + ); + } + + #[test] + fn test_fen_header_case_insensitive() { + // FEN header key should be case-insensitive + let pgn = r#"[fen "r1bqkbnr/pppp1ppp/2n5/4p3/4P3/5N2/PPPP1PPP/RNBQKB1R w KQkq - 2 3"] + +3. Bb5 1-0"#; + let result = parse_single_game_native(pgn, false); + assert!(result.is_ok()); + let extractor = result.unwrap(); + assert!( + extractor.valid_moves, + "Should handle lowercase 'fen' header" + ); + } + + #[test] + fn test_parse_game_with_custom_fen_no_variant() { + // A standard chess game starting from a mid-game position (no Variant header) + // Position after 1.e4 e5 2.Nf3 Nc6 3.Bb5 (Ruy Lopez) + let pgn = r#"[Event "Test Game"] + [FEN "r1bqkbnr/pppp1ppp/2n5/1B2p3/4P3/5N2/PPPP1PPP/RNBQK2R b KQkq - 3 3"] + + 3... a6 4. Ba4 Nf6 5. O-O Be7 1-0"#; + let result = parse_single_game_native(pgn, false); + assert!(result.is_ok()); + let extractor = result.unwrap(); + assert!( + extractor.valid_moves, + "Standard game with custom FEN should be valid" + ); + assert_eq!(extractor.moves.len(), 5); // a6, Ba4, Nf6, O-O, Be7 + } }