#!/usr/bin/env python3
"""
Fix ConceptArc solutions in the database to include ALL test outputs, not just the last one
"""

import json
import os
import sys
from pathlib import Path

import pymysql
from dotenv import load_dotenv

# Stop processing once more than this many rows have failed.
MAX_ROW_ERRORS = 10
# Print a progress line every this many successful updates.
PROGRESS_EVERY = 20
# Number of updated rows shown in the closing sample.
SAMPLE_SIZE = 3


def load_env_config():
    """Load database configuration from .env file.

    Returns:
        A dict of keyword arguments for ``pymysql.connect`` built from the
        DB_HOST, DB_USER, DB_PASSWORD, DB_NAME and DB_PORT environment
        variables. The port falls back to 3306 when DB_PORT is unset or empty.

    Raises:
        ValueError: if DB_PORT is set to a non-empty, non-integer value.
    """
    load_dotenv()
    return {
        'host': os.getenv('DB_HOST'),
        'user': os.getenv('DB_USER'),
        'password': os.getenv('DB_PASSWORD'),
        'database': os.getenv('DB_NAME'),
        # `or` instead of a getenv default: a blank "DB_PORT=" line yields ''
        # and int('') would raise.
        'port': int(os.getenv('DB_PORT') or 3306),
        'charset': 'utf8mb4',
    }


def _extract_test_outputs(puzzle_data):
    """Return the 'output' value of every test case that has one, in order."""
    if 'test' not in puzzle_data:
        return []
    return [case['output'] for case in puzzle_data['test'] if 'output' in case]


def _print_sample(cursor, row_ids):
    """Print a short summary of the solutions stored for the given row ids."""
    placeholders = ', '.join(['%s'] * len(row_ids))
    # Only the placeholder list is interpolated; the ids themselves are bound
    # as query parameters.
    cursor.execute(
        f"""
        SELECT arc_puzzle_id, solution
        FROM arc_jsons
        WHERE id IN ({placeholders})
        ORDER BY arc_puzzle_id
        """,
        row_ids,
    )

    print(f"\n{'=' * 50}")
    print("Sample updated entries:")
    print("=" * 50)

    for puzzle_id, solution in cursor.fetchall():
        if not solution:
            continue
        sol_data = json.loads(solution)
        print(f"\n{puzzle_id}:")
        print(f" Number of test outputs: {len(sol_data)}")
        if not (isinstance(sol_data, list) and sol_data):
            continue
        first_output = sol_data[0]
        if not isinstance(first_output, list):
            continue
        # Guard against an empty grid or a flat (non-nested) output.
        first_row = first_output[0] if first_output else []
        width = len(first_row) if isinstance(first_row, list) else 0
        print(f" First output dimensions: {len(first_output)}x{width}")


def main():
    """Rewrite the stored solution of each ConceptArc entry as all test outputs.

    For every ``arc_jsons`` row joined to an ``arc_puzzles`` row whose corpora
    is 'ConceptArc', the puzzle JSON is parsed, the 'output' of each test case
    is collected, and the JSON-encoded list is written to the ``solution``
    column when it differs from the current value. Asks for confirmation on
    stdin unless ``--yes`` is present in ``sys.argv``.

    Returns:
        0 when the run finishes (including when cancelled or nothing is
        found), 1 when an unexpected error aborts it and the transaction is
        rolled back.
    """
    print("Fixing ConceptArc Solutions")
    print("=" * 50)

    config = load_env_config()
    connection = pymysql.connect(**config)

    try:
        # Opening the cursor inside the try guarantees the connection is
        # closed by the finally clause even if cursor creation fails.
        with connection.cursor() as cursor:
            # Get all ConceptArc entries
            cursor.execute("""
                SELECT aj.id, aj.arc_puzzle_id, aj.json, aj.solution
                FROM arc_jsons aj
                JOIN arc_puzzles ap ON aj.arc_puzzle_id = ap.id
                WHERE ap.corpora = 'ConceptArc'
                ORDER BY aj.arc_puzzle_id
            """)
            entries = cursor.fetchall()
            print(f"Found {len(entries)} ConceptArc entries to check")

            if not entries:
                print("No ConceptArc entries found!")
                return 0

            # Ask for confirmation
            if '--yes' not in sys.argv:
                response = input(
                    f"\nUpdate solutions for {len(entries)} entries? (yes/no): "
                ).strip().lower()
                if response not in ['yes', 'y']:
                    print("Operation cancelled")
                    return 0
            else:
                print("Updating solutions (auto-confirmed with --yes flag)")

            updated = 0
            errors = 0
            unchanged = 0
            no_outputs = 0
            sample_ids = []  # ids of the first few rows actually updated

            print("\nProcessing entries...")
            for row_id, puzzle_id, json_str, current_solution in entries:
                try:
                    all_outputs = _extract_test_outputs(json.loads(json_str))
                    if not all_outputs:
                        print(f" ⚠ {puzzle_id}: No test outputs found")
                        no_outputs += 1
                        continue

                    # Create new solution as array of all outputs
                    new_solution = json.dumps(all_outputs)
                    if current_solution == new_solution:
                        unchanged += 1
                        continue

                    cursor.execute("""
                        UPDATE arc_jsons
                        SET solution = %s
                        WHERE id = %s
                    """, (new_solution, row_id))
                    updated += 1
                    if len(sample_ids) < SAMPLE_SIZE:
                        sample_ids.append(row_id)

                    if updated % PROGRESS_EVERY == 0:
                        print(f" Updated: {updated}/{len(entries)}")
                except Exception as e:
                    # A bad row is reported and counted; the run continues
                    # until the error budget is exhausted.
                    errors += 1
                    print(f" ✗ Error with {puzzle_id}: {e}")
                    if errors > MAX_ROW_ERRORS:
                        print("Too many errors, stopping...")
                        break

            # Commit changes (rows updated before an early stop are kept)
            connection.commit()

            print(f"\n{'=' * 50}")
            print("✓ Solution update complete!")
            print(f" Updated: {updated}")
            print(f" Skipped (unchanged): {unchanged}")
            print(f" Skipped (no test outputs): {no_outputs}")
            print(f" Errors: {errors}")

            # Show a sample of updated solutions
            if sample_ids:
                try:
                    _print_sample(cursor, sample_ids)
                except Exception as e:
                    # Display only: the update is already committed, so a
                    # failure here must not turn the run into an error.
                    print(f"\n⚠ Could not display sample entries: {e}")

    except Exception as e:
        connection.rollback()
        print(f"\n✗ Error: {e}")
        import traceback
        traceback.print_exc()
        return 1
    finally:
        connection.close()
        print("\n✓ Database connection closed")

    return 0


if __name__ == '__main__':
    sys.exit(main())