initial
This commit is contained in:
222
scripts/extract_solutions.py
Normal file
222
scripts/extract_solutions.py
Normal file
@ -0,0 +1,222 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Extract test outputs from ARC puzzle JSON and store in solution column
|
||||
This script adds a 'solution' column and populates it with the test output grid
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import pymysql
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Force unbuffered output
|
||||
sys.stdout.reconfigure(line_buffering=True)
|
||||
sys.stderr.reconfigure(line_buffering=True)
|
||||
|
||||
|
||||
def load_env_config():
|
||||
"""Load database configuration from .env file"""
|
||||
load_dotenv()
|
||||
|
||||
config = {
|
||||
'host': os.getenv('DB_HOST'),
|
||||
'user': os.getenv('DB_USER'),
|
||||
'password': os.getenv('DB_PASSWORD'),
|
||||
'database': os.getenv('DB_NAME'),
|
||||
'port': int(os.getenv('DB_PORT', 3306)),
|
||||
'charset': 'utf8mb4'
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def check_and_add_solution_column(cursor):
|
||||
"""Check if solution column exists, add it if not"""
|
||||
cursor.execute("""
|
||||
SELECT COLUMN_NAME
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA = DATABASE()
|
||||
AND TABLE_NAME = 'arc_jsons'
|
||||
AND COLUMN_NAME = 'solution'
|
||||
""")
|
||||
|
||||
if cursor.fetchone():
|
||||
print("✓ Column 'solution' already exists")
|
||||
return True
|
||||
|
||||
print("Adding 'solution' column to arc_jsons table...")
|
||||
cursor.execute("ALTER TABLE arc_jsons ADD COLUMN solution JSON AFTER json")
|
||||
print("✓ Column 'solution' added successfully")
|
||||
return True
|
||||
|
||||
|
||||
def extract_test_output(json_content):
|
||||
"""Extract test output from puzzle JSON"""
|
||||
try:
|
||||
# Parse JSON (handle both string and object)
|
||||
if isinstance(json_content, str):
|
||||
puzzle_data = json.loads(json_content)
|
||||
else:
|
||||
puzzle_data = json_content
|
||||
|
||||
# Extract test output
|
||||
if puzzle_data.get('test') and len(puzzle_data['test']) > 0:
|
||||
test_case = puzzle_data['test'][0]
|
||||
if test_case.get('output'):
|
||||
return test_case['output']
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error parsing JSON: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def update_solution(cursor, puzzle_id, json_content):
|
||||
"""Extract test output and update solution column"""
|
||||
test_output = extract_test_output(json_content)
|
||||
|
||||
if test_output is None:
|
||||
return False, "No test output found"
|
||||
|
||||
try:
|
||||
# Convert to JSON string
|
||||
solution_json = json.dumps(test_output)
|
||||
|
||||
# Update database
|
||||
sql = "UPDATE arc_jsons SET solution = %s WHERE id = %s"
|
||||
cursor.execute(sql, (solution_json, puzzle_id))
|
||||
|
||||
return True, None
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
|
||||
def main():
|
||||
print("ARC Solution Extraction Tool")
|
||||
print("=" * 60)
|
||||
print("This script will:")
|
||||
print(" 1. Add a 'solution' column to arc_jsons (if needed)")
|
||||
print(" 2. Extract test outputs from JSON data")
|
||||
print(" 3. Store solutions in the new column")
|
||||
print("=" * 60)
|
||||
|
||||
# Load configuration
|
||||
try:
|
||||
config = load_env_config()
|
||||
print(f"\n✓ Loaded configuration from .env")
|
||||
print(f" Host: {config['host']}")
|
||||
print(f" Database: {config['database']}")
|
||||
print(f" User: {config['user']}")
|
||||
except Exception as e:
|
||||
print(f"✗ Error loading configuration: {e}")
|
||||
return 1
|
||||
|
||||
# Connect to database
|
||||
try:
|
||||
print(f"\nConnecting to database...")
|
||||
connection = pymysql.connect(**config)
|
||||
print(f"✓ Connected successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Database connection failed: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
cursor = connection.cursor()
|
||||
|
||||
# Check if table exists
|
||||
cursor.execute("SHOW TABLES LIKE 'arc_jsons'")
|
||||
if not cursor.fetchone():
|
||||
print(f"✗ Table 'arc_jsons' does not exist")
|
||||
return 1
|
||||
|
||||
# Get current count
|
||||
cursor.execute("SELECT COUNT(*) FROM arc_jsons")
|
||||
total_count = cursor.fetchone()[0]
|
||||
print(f"✓ Table 'arc_jsons' found ({total_count} records)")
|
||||
|
||||
# Check/add solution column
|
||||
print()
|
||||
check_and_add_solution_column(cursor)
|
||||
connection.commit()
|
||||
|
||||
# Ask for confirmation (unless --yes flag is provided)
|
||||
if '--yes' not in sys.argv:
|
||||
print(f"\n⚠ About to process {total_count} records")
|
||||
response = input("Continue? (yes/no): ").strip().lower()
|
||||
if response not in ['yes', 'y']:
|
||||
print("Extraction cancelled")
|
||||
return 0
|
||||
else:
|
||||
print(f"\n⚠ About to process {total_count} records (auto-confirmed with --yes flag)")
|
||||
|
||||
# Fetch all records
|
||||
print(f"\nFetching records...")
|
||||
cursor.execute("SELECT id, json FROM arc_jsons")
|
||||
records = cursor.fetchall()
|
||||
print(f"✓ Retrieved {len(records)} records")
|
||||
|
||||
# Process each record
|
||||
print(f"\nProcessing records...")
|
||||
updated = 0
|
||||
errors = 0
|
||||
no_output = 0
|
||||
|
||||
for i, (puzzle_id, json_content) in enumerate(records, 1):
|
||||
success, error = update_solution(cursor, puzzle_id, json_content)
|
||||
|
||||
if success:
|
||||
updated += 1
|
||||
elif error == "No test output found":
|
||||
no_output += 1
|
||||
if no_output <= 5: # Show first 5 cases
|
||||
print(f" ⚠ No output: {puzzle_id}")
|
||||
else:
|
||||
errors += 1
|
||||
if errors <= 5: # Show first 5 errors
|
||||
print(f" ✗ Error {puzzle_id}: {error}")
|
||||
|
||||
# Show progress every 100 records
|
||||
if i % 100 == 0 or i == len(records):
|
||||
print(f" Progress: {i}/{len(records)} ({updated} updated, {no_output} no output, {errors} errors)")
|
||||
|
||||
# Commit the transaction
|
||||
connection.commit()
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"✓ Extraction complete!")
|
||||
print(f" Successfully updated: {updated}")
|
||||
print(f" No test output: {no_output}")
|
||||
print(f" Errors: {errors}")
|
||||
print(f" Total processed: {len(records)}")
|
||||
|
||||
# Show sample
|
||||
if updated > 0:
|
||||
print(f"\nSample record (first with solution):")
|
||||
cursor.execute("SELECT id, solution FROM arc_jsons WHERE solution IS NOT NULL LIMIT 1")
|
||||
sample = cursor.fetchone()
|
||||
if sample:
|
||||
sample_id, sample_solution = sample
|
||||
solution_data = json.loads(sample_solution)
|
||||
rows = len(solution_data)
|
||||
cols = len(solution_data[0]) if rows > 0 else 0
|
||||
print(f" ID: {sample_id}")
|
||||
print(f" Solution grid: {cols}×{rows}")
|
||||
print(f" First row: {solution_data[0][:10]}..." if cols > 10 else f" First row: {solution_data[0]}")
|
||||
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
print(f"\n✗ Error during extraction: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
finally:
|
||||
connection.close()
|
||||
print(f"\n✓ Database connection closed")
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user