ConceptARC db upload
This commit is contained in:
131
scripts/verify_solution_format.py
Normal file
131
scripts/verify_solution_format.py
Normal file
@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Verify that solutions are stored correctly for different corpora
|
||||
- V1/V2/evaluation: Single grid (one test case)
|
||||
- ConceptArc: Array of grids (multiple test cases)
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def load_env_config():
|
||||
"""Load database configuration from .env file"""
|
||||
load_dotenv()
|
||||
return {
|
||||
'host': os.getenv('DB_HOST'),
|
||||
'user': os.getenv('DB_USER'),
|
||||
'password': os.getenv('DB_PASSWORD'),
|
||||
'database': os.getenv('DB_NAME'),
|
||||
'port': int(os.getenv('DB_PORT', 3306)),
|
||||
'charset': 'utf8mb4'
|
||||
}
|
||||
|
||||
def check_solution_format(cursor, corpora_name, expected_test_count):
|
||||
"""Check solution format for a specific corpora"""
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Checking {corpora_name} puzzles:")
|
||||
print('='*50)
|
||||
|
||||
cursor.execute(f"""
|
||||
SELECT aj.arc_puzzle_id, aj.json, aj.solution
|
||||
FROM arc_jsons aj
|
||||
JOIN arc_puzzles ap ON aj.arc_puzzle_id = ap.id
|
||||
WHERE ap.corpora = %s
|
||||
LIMIT 5
|
||||
""", (corpora_name,))
|
||||
|
||||
results = cursor.fetchall()
|
||||
|
||||
if not results:
|
||||
print(f"No {corpora_name} puzzles found")
|
||||
return
|
||||
|
||||
mismatches = []
|
||||
|
||||
for puzzle_id, json_str, solution in results:
|
||||
puzzle_data = json.loads(json_str)
|
||||
test_count = len(puzzle_data.get('test', []))
|
||||
|
||||
if not solution:
|
||||
print(f"⚠ {puzzle_id}: No solution stored!")
|
||||
continue
|
||||
|
||||
sol = json.loads(solution)
|
||||
|
||||
# Determine solution structure
|
||||
if isinstance(sol, list) and len(sol) > 0:
|
||||
# Check if it's array of grids or single grid
|
||||
if isinstance(sol[0], list) and len(sol[0]) > 0 and isinstance(sol[0][0], list):
|
||||
# Array of grids (ConceptArc style)
|
||||
sol_count = len(sol)
|
||||
structure = f"Array of {sol_count} grids"
|
||||
else:
|
||||
# Single grid (regular ARC style)
|
||||
sol_count = 1
|
||||
structure = f"Single grid ({len(sol)}x{len(sol[0]) if sol else 0})"
|
||||
else:
|
||||
structure = "Unknown format"
|
||||
sol_count = 0
|
||||
|
||||
match = "✓" if sol_count == test_count else "✗"
|
||||
print(f"{match} {puzzle_id}: {test_count} tests, {structure}")
|
||||
|
||||
if sol_count != test_count:
|
||||
mismatches.append((puzzle_id, test_count, sol_count))
|
||||
|
||||
if mismatches:
|
||||
print(f"\n⚠ Found {len(mismatches)} mismatches:")
|
||||
for pid, expected, actual in mismatches:
|
||||
print(f" {pid}: Expected {expected} solutions, got {actual}")
|
||||
else:
|
||||
print(f"\n✓ All solutions match their test counts!")
|
||||
|
||||
def main():
|
||||
print("Verifying Solution Formats")
|
||||
print("=" * 50)
|
||||
|
||||
config = load_env_config()
|
||||
connection = pymysql.connect(**config)
|
||||
cursor = connection.cursor()
|
||||
|
||||
try:
|
||||
# Check different corpora
|
||||
check_solution_format(cursor, "V1", 1)
|
||||
check_solution_format(cursor, "V2", 1)
|
||||
check_solution_format(cursor, "evaluation", 1)
|
||||
check_solution_format(cursor, "ConceptArc", 3)
|
||||
|
||||
# Summary stats
|
||||
print(f"\n{'='*50}")
|
||||
print("Summary by corpora:")
|
||||
print('='*50)
|
||||
|
||||
cursor.execute("""
|
||||
SELECT ap.corpora,
|
||||
COUNT(*) as total,
|
||||
COUNT(aj.solution) as with_solution
|
||||
FROM arc_puzzles ap
|
||||
JOIN arc_jsons aj ON ap.id = aj.arc_puzzle_id
|
||||
GROUP BY ap.corpora
|
||||
ORDER BY ap.corpora
|
||||
""")
|
||||
|
||||
for corpora, total, with_sol in cursor.fetchall():
|
||||
print(f" {corpora}: {with_sol}/{total} have solutions")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
finally:
|
||||
connection.close()
|
||||
print(f"\n✓ Database connection closed")
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user