Files
arc-humans-interface-db/scripts/extract_solutions.py
bmachado b8856c0660 initial
2025-11-05 00:24:05 +00:00

223 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Extract test outputs from ARC puzzle JSON and store in solution column
This script adds a 'solution' column and populates it with the test output grid
"""
import json
import os
import sys
from pathlib import Path
import pymysql
from dotenv import load_dotenv
# Make stdout/stderr line-buffered so progress lines show up promptly even when
# output is redirected to a file or pipe. (This is line buffering, not fully
# unbuffered output.) Streams that are missing (None, e.g. under pythonw) or that
# were replaced by objects lacking ``reconfigure`` are left untouched instead of
# crashing the script at import time.
for _stream in (sys.stdout, sys.stderr):
    if hasattr(_stream, "reconfigure"):
        _stream.reconfigure(line_buffering=True)
del _stream
def load_env_config():
    """Load database connection settings from the environment / ``.env`` file.

    Calls ``load_dotenv()`` first, then reads DB_HOST, DB_USER, DB_PASSWORD,
    DB_NAME and DB_PORT. DB_PORT falls back to 3306 when it is unset *or*
    set to an empty string.

    Returns:
        dict of keyword arguments suitable for ``pymysql.connect``.

    Raises:
        ValueError: if DB_NAME is missing/empty (every query in this script
            targets ``arc_jsons`` in the connection's default database), or if
            DB_PORT is set but is not an integer.
    """
    load_dotenv()

    database = os.getenv('DB_NAME')
    if not database:
        raise ValueError("DB_NAME is not set (check your .env file)")

    # `os.getenv('DB_PORT', 3306)` would return '' for a blank `DB_PORT=` line,
    # and int('') fails; treat blank the same as unset.
    raw_port = (os.getenv('DB_PORT') or '').strip()
    try:
        port = int(raw_port) if raw_port else 3306
    except ValueError:
        raise ValueError(f"DB_PORT must be an integer, got {raw_port!r}") from None

    return {
        'host': os.getenv('DB_HOST'),
        'user': os.getenv('DB_USER'),
        'password': os.getenv('DB_PASSWORD'),
        'database': database,
        'port': port,
        'charset': 'utf8mb4',
    }
def check_and_add_solution_column(cursor):
    """Ensure ``arc_jsons`` has a ``solution`` column, creating it when absent.

    Queries INFORMATION_SCHEMA for the column in the current database. If no
    row comes back, issues ``ALTER TABLE`` to add a ``solution JSON`` column
    placed after the ``json`` column. Prints which path was taken.

    Args:
        cursor: an open DB-API cursor.

    Returns:
        True in both cases.
    """
    column_lookup = """
        SELECT COLUMN_NAME
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = DATABASE()
          AND TABLE_NAME = 'arc_jsons'
          AND COLUMN_NAME = 'solution'
    """
    cursor.execute(column_lookup)
    column_exists = bool(cursor.fetchone())

    if not column_exists:
        print("Adding 'solution' column to arc_jsons table...")
        cursor.execute("ALTER TABLE arc_jsons ADD COLUMN solution JSON AFTER json")
        print("✓ Column 'solution' added successfully")
    else:
        print("✓ Column 'solution' already exists")
    return True
def extract_test_output(json_content):
    """Return the expected output grid of the puzzle's first test case.

    Args:
        json_content: the puzzle JSON, either as text (``str``, ``bytes`` or
            ``bytearray``, e.g. as returned by a DB driver) or as an
            already-decoded object.

    Returns:
        The ``output`` value of ``test[0]`` when it is present and non-empty,
        otherwise None. Only the first entry of ``test`` is used; any further
        test cases are ignored. Text that is not valid JSON is reported on
        stdout and also yields None.
    """
    if isinstance(json_content, (str, bytes, bytearray)):
        try:
            puzzle_data = json.loads(json_content)
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
            print(f"Error parsing JSON: {e}")
            return None
    else:
        puzzle_data = json_content

    # Validate the shape explicitly rather than relying on a catch-all except.
    if not isinstance(puzzle_data, dict):
        return None
    tests = puzzle_data.get('test')
    if not isinstance(tests, (list, tuple)) or not tests:
        return None
    first_test = tests[0]
    if not isinstance(first_test, dict):
        return None
    output = first_test.get('output')
    # Empty/missing outputs are reported as "no output" (None).
    return output if output else None
def update_solution(cursor, puzzle_id, json_content):
    """Store the puzzle's extracted test output in ``arc_jsons.solution``.

    Args:
        cursor: an open DB-API cursor (the caller commits).
        puzzle_id: value matched against ``arc_jsons.id``.
        json_content: puzzle JSON handed to ``extract_test_output``.

    Returns:
        A ``(success, error)`` pair:
        ``(True, None)`` once the UPDATE has been issued;
        ``(False, "No test output found")`` when no output grid was extracted;
        ``(False, <exception text>)`` when serialising or the UPDATE raised.
        ``main`` compares against the exact "No test output found" text.
    """
    grid = extract_test_output(json_content)
    if grid is None:
        return False, "No test output found"

    statement = "UPDATE arc_jsons SET solution = %s WHERE id = %s"
    try:
        cursor.execute(statement, (json.dumps(grid), puzzle_id))
    except Exception as exc:
        # Per-record failure: hand the message back so the batch keeps going.
        return False, str(exc)
    return True, None
def _process_records(cursor, records):
    """Apply update_solution to every (id, json) row, printing progress.

    Prints the first five "no output" ids and the first five errors, plus a
    progress line every 100 records and after the last record.

    Returns:
        (updated, no_output, errors) counts.
    """
    total = len(records)
    updated = no_output = errors = 0
    for i, (puzzle_id, json_content) in enumerate(records, 1):
        success, error = update_solution(cursor, puzzle_id, json_content)
        if success:
            updated += 1
        elif error == "No test output found":
            no_output += 1
            if no_output <= 5:  # Show first 5 cases
                print(f" ⚠ No output: {puzzle_id}")
        else:
            errors += 1
            if errors <= 5:  # Show first 5 errors
                print(f" ✗ Error {puzzle_id}: {error}")
        # Show progress every 100 records and at the end
        if i % 100 == 0 or i == total:
            print(f" Progress: {i}/{total} ({updated} updated, {no_output} no output, {errors} errors)")
    return updated, no_output, errors


def _print_sample(cursor):
    """Print one stored solution as a sanity check (best effort).

    This runs after the data has been committed, so any failure here is
    reported but does not turn a successful run into a failed one.
    """
    print("\nSample record (first with solution):")
    try:
        cursor.execute("SELECT id, solution FROM arc_jsons WHERE solution IS NOT NULL LIMIT 1")
        sample = cursor.fetchone()
        if not sample:
            return
        sample_id, sample_solution = sample
        solution_data = json.loads(sample_solution)
        rows = len(solution_data)
        cols = len(solution_data[0]) if rows > 0 else 0
        print(f" ID: {sample_id}")
        print(f" Solution grid: {cols}×{rows}")
        # Guard: a stored empty grid would otherwise raise IndexError here.
        if rows > 0:
            first_row = solution_data[0]
            print(f" First row: {first_row[:10]}..." if cols > 10 else f" First row: {first_row}")
    except Exception as e:  # display only; the extraction itself already succeeded
        print(f" (could not display sample: {e})")


def main():
    """Add and populate the ``solution`` column of ``arc_jsons``.

    Loads DB settings via load_env_config(), connects with pymysql, ensures the
    ``solution`` column exists, and asks for confirmation unless ``--yes`` is
    on the command line. It then stores each puzzle's extracted test output and
    commits once after all rows are processed.

    Returns:
        Process exit code: 0 on success or user cancellation, 1 on failure
        (including a missing confirmation when stdin is not interactive).
    """
    print("ARC Solution Extraction Tool")
    print("=" * 60)
    print("This script will:")
    print(" 1. Add a 'solution' column to arc_jsons (if needed)")
    print(" 2. Extract test outputs from JSON data")
    print(" 3. Store solutions in the new column")
    print("=" * 60)

    # Load configuration
    try:
        config = load_env_config()
        print("\n✓ Loaded configuration from .env")
        print(f" Host: {config['host']}")
        print(f" Database: {config['database']}")
        print(f" User: {config['user']}")
    except Exception as e:
        print(f"✗ Error loading configuration: {e}")
        return 1

    # Connect to database
    try:
        print("\nConnecting to database...")
        connection = pymysql.connect(**config)
        print("✓ Connected successfully")
    except Exception as e:
        print(f"✗ Database connection failed: {e}")
        return 1

    try:
        cursor = connection.cursor()

        # Check if table exists
        cursor.execute("SHOW TABLES LIKE 'arc_jsons'")
        if not cursor.fetchone():
            print("✗ Table 'arc_jsons' does not exist")
            return 1

        cursor.execute("SELECT COUNT(*) FROM arc_jsons")
        total_count = cursor.fetchone()[0]
        print(f"✓ Table 'arc_jsons' found ({total_count} records)")

        # Check/add solution column
        print()
        check_and_add_solution_column(cursor)
        connection.commit()

        # Ask for confirmation (unless --yes flag is provided)
        if '--yes' not in sys.argv:
            print(f"\n⚠ About to process {total_count} records")
            try:
                response = input("Continue? (yes/no): ").strip().lower()
            except EOFError:
                # Non-interactive run (cron/CI) without --yes: fail clearly
                # instead of dumping a generic traceback.
                print("\n✗ No confirmation received (stdin is closed); "
                      "re-run with --yes to skip the prompt")
                return 1
            if response not in ('yes', 'y'):
                print("Extraction cancelled")
                return 0
        else:
            print(f"\n⚠ About to process {total_count} records (auto-confirmed with --yes flag)")

        print("\nFetching records...")
        cursor.execute("SELECT id, json FROM arc_jsons")
        records = cursor.fetchall()
        print(f"✓ Retrieved {len(records)} records")

        print("\nProcessing records...")
        updated, no_output, errors = _process_records(cursor, records)

        # Commit all per-record UPDATEs issued above.
        connection.commit()

        print(f"\n{'=' * 60}")
        print("✓ Extraction complete!")
        print(f" Successfully updated: {updated}")
        print(f" No test output: {no_output}")
        print(f" Errors: {errors}")
        print(f" Total processed: {len(records)}")

        if updated > 0:
            _print_sample(cursor)
    except Exception as e:
        try:
            connection.rollback()
        except pymysql.Error:
            # The connection may already be unusable; don't let that mask
            # the original error, which is reported below.
            pass
        print(f"\n✗ Error during extraction: {e}")
        import traceback
        traceback.print_exc()
        return 1
    finally:
        connection.close()
        print("\n✓ Database connection closed")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())