#!/usr/bin/env python3
"""
Extract test outputs from ARC puzzle JSON and store them in the solution column.
This script adds a 'solution' column and populates it with the test output grid.
"""
|
||
|
||
import json
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
import pymysql
|
||
from dotenv import load_dotenv
|
||
|
||
# Switch stdout/stderr to line buffering so each progress line is flushed as
# soon as it is printed (this is line-buffered, not fully unbuffered, output).
# reconfigure() only exists on io.TextIOWrapper streams; a replaced stream
# (e.g. io.StringIO) or a missing one (None) is skipped instead of crashing
# the module at import time.
for _stream in (sys.stdout, sys.stderr):
    _reconfigure = getattr(_stream, "reconfigure", None)
    if callable(_reconfigure):
        _reconfigure(line_buffering=True)
del _stream, _reconfigure
|
||
|
||
|
||
def load_env_config():
    """Build the keyword arguments used to open the database connection.

    Calls ``load_dotenv()`` and then reads the ``DB_*`` environment variables.

    Returns:
        A dict with the keys ``host``, ``user``, ``password``, ``database``,
        ``port`` and ``charset``. Variables that are not set yield ``None``;
        ``port`` falls back to 3306 and ``charset`` is always ``'utf8mb4'``.

    Raises:
        ValueError: if ``DB_PORT`` is set to something ``int()`` cannot parse.
    """
    load_dotenv()

    # Settings that are copied verbatim from the environment.
    env_names = (
        ('host', 'DB_HOST'),
        ('user', 'DB_USER'),
        ('password', 'DB_PASSWORD'),
        ('database', 'DB_NAME'),
    )
    settings = {key: os.getenv(variable) for key, variable in env_names}

    # The port must be numeric; the charset is fixed.
    settings['port'] = int(os.getenv('DB_PORT', 3306))
    settings['charset'] = 'utf8mb4'
    return settings
|
||
|
||
|
||
def check_and_add_solution_column(cursor):
    """Ensure the ``arc_jsons`` table has a ``solution`` column.

    Looks the column up in ``INFORMATION_SCHEMA.COLUMNS`` for the current
    database and, when no row comes back, issues an ``ALTER TABLE`` that adds
    it as a JSON column placed after the ``json`` column.

    Args:
        cursor: Database cursor offering ``execute()`` and ``fetchone()``.

    Returns:
        Always ``True``.
    """
    lookup_sql = """
        SELECT COLUMN_NAME
        FROM INFORMATION_SCHEMA.COLUMNS
        WHERE TABLE_SCHEMA = DATABASE()
        AND TABLE_NAME = 'arc_jsons'
        AND COLUMN_NAME = 'solution'
    """
    cursor.execute(lookup_sql)

    if not cursor.fetchone():
        # No matching row: the column is missing, so create it.
        print("Adding 'solution' column to arc_jsons table...")
        cursor.execute("ALTER TABLE arc_jsons ADD COLUMN solution JSON AFTER json")
        print("✓ Column 'solution' added successfully")
    else:
        print("✓ Column 'solution' already exists")

    return True
|
||
|
||
|
||
def extract_test_output(json_content):
    """Return the output grid of the first test case of an ARC puzzle.

    Args:
        json_content: The puzzle either as a JSON document (``str``,
            ``bytes`` or ``bytearray``) or as an already-decoded object.

    Returns:
        The value stored under ``test[0]["output"]``. ``None`` is returned
        when the document cannot be parsed, when it does not have that shape
        (no ``test`` list, empty list, first entry not an object), or when
        the output is missing or empty.
    """
    # Drivers may hand back the JSON column as text or as raw bytes;
    # json.loads accepts both.
    if isinstance(json_content, (str, bytes, bytearray)):
        try:
            puzzle_data = json.loads(json_content)
        except ValueError as e:
            # Covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error parsing JSON: {e}")
            return None
    else:
        puzzle_data = json_content

    # Validate the shape explicitly instead of relying on a blanket except.
    if not isinstance(puzzle_data, dict):
        return None

    test_cases = puzzle_data.get('test')
    if not isinstance(test_cases, (list, tuple)) or not test_cases:
        return None

    first_case = test_cases[0]
    if not isinstance(first_case, dict):
        return None

    # A missing or empty output is reported as "no output".
    return first_case.get('output') or None
|
||
|
||
|
||
def update_solution(cursor, puzzle_id, json_content):
    """Store the test output of one puzzle in its ``solution`` column.

    Args:
        cursor: Database cursor offering ``execute()``.
        puzzle_id: Value matched against ``arc_jsons.id``.
        json_content: Puzzle JSON passed on to ``extract_test_output``.

    Returns:
        A ``(success, error)`` pair: ``(True, None)`` after the UPDATE was
        issued, ``(False, "No test output found")`` when the puzzle has no
        test output, or ``(False, <exception text>)`` when serialising or
        executing the UPDATE raised.
    """
    grid = extract_test_output(json_content)
    if grid is None:
        return False, "No test output found"

    try:
        # Serialise the grid and write it in one parameterised statement.
        cursor.execute(
            "UPDATE arc_jsons SET solution = %s WHERE id = %s",
            (json.dumps(grid), puzzle_id),
        )
    except Exception as exc:
        return False, str(exc)

    return True, None
|
||
|
||
|
||
def _confirm_run(total_count):
    """Return True when processing may start (``--yes`` flag or a yes/y answer)."""
    if '--yes' in sys.argv:
        print(f"\n⚠ About to process {total_count} records (auto-confirmed with --yes flag)")
        return True

    print(f"\n⚠ About to process {total_count} records")
    response = input("Continue? (yes/no): ").strip().lower()
    return response in ('yes', 'y')


def _process_records(cursor, records):
    """Run update_solution on every (id, json) row; return (updated, no_output, errors)."""
    updated = no_output = errors = 0
    total = len(records)

    for i, (puzzle_id, json_content) in enumerate(records, 1):
        success, error = update_solution(cursor, puzzle_id, json_content)

        if success:
            updated += 1
        elif error == "No test output found":
            no_output += 1
            if no_output <= 5:  # Show first 5 cases
                print(f" ⚠ No output: {puzzle_id}")
        else:
            errors += 1
            if errors <= 5:  # Show first 5 errors
                print(f" ✗ Error {puzzle_id}: {error}")

        # Show progress every 100 records and once more at the very end
        if i % 100 == 0 or i == total:
            print(f" Progress: {i}/{total} ({updated} updated, {no_output} no output, {errors} errors)")

    return updated, no_output, errors


def _print_sample_solution(cursor):
    """Print the id, dimensions and first row of one stored solution, if any."""
    print("\nSample record (first with solution):")
    cursor.execute("SELECT id, solution FROM arc_jsons WHERE solution IS NOT NULL LIMIT 1")
    sample = cursor.fetchone()
    if not sample:
        return

    sample_id, sample_solution = sample
    # Decode only when the driver returned the JSON column as text/bytes.
    if isinstance(sample_solution, (str, bytes, bytearray)):
        solution_data = json.loads(sample_solution)
    else:
        solution_data = sample_solution

    rows = len(solution_data)
    cols = len(solution_data[0]) if rows > 0 else 0
    print(f" ID: {sample_id}")
    print(f" Solution grid: {cols}×{rows}")
    if rows == 0:
        # An empty grid has no first row to show.
        return

    first_row = solution_data[0]
    print(f" First row: {first_row[:10]}..." if cols > 10 else f" First row: {first_row}")


def main():
    """Add the ``solution`` column if needed and fill it for every puzzle.

    Steps: load the DB settings, connect, verify that ``arc_jsons`` exists,
    ensure the ``solution`` column, ask for confirmation (skipped with
    ``--yes`` on the command line), update every record, commit, print a
    summary and a sample of one stored solution.

    Returns:
        0 on success or when the user declines to continue; 1 when the
        configuration cannot be loaded, the connection fails, the table is
        missing, or an error occurs while processing.
    """
    print("ARC Solution Extraction Tool")
    print("=" * 60)
    print("This script will:")
    print(" 1. Add a 'solution' column to arc_jsons (if needed)")
    print(" 2. Extract test outputs from JSON data")
    print(" 3. Store solutions in the new column")
    print("=" * 60)

    # Load configuration
    try:
        config = load_env_config()
        print("\n✓ Loaded configuration from .env")
        print(f" Host: {config['host']}")
        print(f" Database: {config['database']}")
        print(f" User: {config['user']}")
    except Exception as e:
        print(f"✗ Error loading configuration: {e}")
        return 1

    # Connect to database
    try:
        print("\nConnecting to database...")
        connection = pymysql.connect(**config)
        print("✓ Connected successfully")
    except Exception as e:
        print(f"✗ Database connection failed: {e}")
        return 1

    try:
        cursor = connection.cursor()

        # Check if table exists
        cursor.execute("SHOW TABLES LIKE 'arc_jsons'")
        if not cursor.fetchone():
            print("✗ Table 'arc_jsons' does not exist")
            return 1

        # Get current count
        cursor.execute("SELECT COUNT(*) FROM arc_jsons")
        total_count = cursor.fetchone()[0]
        print(f"✓ Table 'arc_jsons' found ({total_count} records)")

        # Check/add solution column
        print()
        check_and_add_solution_column(cursor)
        connection.commit()

        # Ask for confirmation (unless --yes flag is provided)
        if not _confirm_run(total_count):
            print("Extraction cancelled")
            return 0

        # Fetch all records
        print("\nFetching records...")
        cursor.execute("SELECT id, json FROM arc_jsons")
        records = cursor.fetchall()
        print(f"✓ Retrieved {len(records)} records")

        # Process each record
        print("\nProcessing records...")
        updated, no_output, errors = _process_records(cursor, records)

        # Commit the transaction
        connection.commit()

        print(f"\n{'=' * 60}")
        print("✓ Extraction complete!")
        print(f" Successfully updated: {updated}")
        print(f" No test output: {no_output}")
        print(f" Errors: {errors}")
        print(f" Total processed: {len(records)}")

        # Show sample. The data is already committed at this point, so a
        # failure in this purely informational step must not be reported as
        # a failed extraction (rollback + exit code 1).
        if updated > 0:
            try:
                _print_sample_solution(cursor)
            except Exception as e:
                print(f" ⚠ Could not display sample record: {e}")

    except Exception as e:
        connection.rollback()
        print(f"\n✗ Error during extraction: {e}")
        import traceback
        traceback.print_exc()
        return 1
    finally:
        connection.close()
        print("\n✓ Database connection closed")

    return 0
|
||
|
||
|
||
# Script entry point: terminate the process with main()'s return code.
if __name__ == '__main__':
    raise SystemExit(main())
|