ConceptARC db upload
This commit is contained in:
155
scripts/fix_conceptarc_solutions.py
Executable file
155
scripts/fix_conceptarc_solutions.py
Executable file
@ -0,0 +1,155 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fix ConceptArc solutions in the database to include ALL test outputs, not just the last one
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import pymysql
|
||||
from dotenv import load_dotenv
|
||||
|
||||
def load_env_config():
|
||||
"""Load database configuration from .env file"""
|
||||
load_dotenv()
|
||||
return {
|
||||
'host': os.getenv('DB_HOST'),
|
||||
'user': os.getenv('DB_USER'),
|
||||
'password': os.getenv('DB_PASSWORD'),
|
||||
'database': os.getenv('DB_NAME'),
|
||||
'port': int(os.getenv('DB_PORT', 3306)),
|
||||
'charset': 'utf8mb4'
|
||||
}
|
||||
|
||||
def main():
|
||||
print("Fixing ConceptArc Solutions")
|
||||
print("=" * 50)
|
||||
|
||||
config = load_env_config()
|
||||
connection = pymysql.connect(**config)
|
||||
cursor = connection.cursor()
|
||||
|
||||
try:
|
||||
# Get all ConceptArc entries
|
||||
cursor.execute("""
|
||||
SELECT aj.id, aj.arc_puzzle_id, aj.json, aj.solution
|
||||
FROM arc_jsons aj
|
||||
JOIN arc_puzzles ap ON aj.arc_puzzle_id = ap.id
|
||||
WHERE ap.corpora = 'ConceptArc'
|
||||
ORDER BY aj.arc_puzzle_id
|
||||
""")
|
||||
|
||||
entries = cursor.fetchall()
|
||||
print(f"Found {len(entries)} ConceptArc entries to check")
|
||||
|
||||
if not entries:
|
||||
print("No ConceptArc entries found!")
|
||||
return 0
|
||||
|
||||
# Ask for confirmation
|
||||
if '--yes' not in sys.argv:
|
||||
response = input(f"\nUpdate solutions for {len(entries)} entries? (yes/no): ").strip().lower()
|
||||
if response not in ['yes', 'y']:
|
||||
print("Operation cancelled")
|
||||
return 0
|
||||
else:
|
||||
print(f"Updating solutions (auto-confirmed with --yes flag)")
|
||||
|
||||
updated = 0
|
||||
errors = 0
|
||||
skipped = 0
|
||||
|
||||
print("\nProcessing entries...")
|
||||
|
||||
for row_id, puzzle_id, json_str, current_solution in entries:
|
||||
try:
|
||||
# Parse the puzzle JSON
|
||||
puzzle_data = json.loads(json_str)
|
||||
|
||||
# Extract all test outputs
|
||||
all_outputs = []
|
||||
if 'test' in puzzle_data:
|
||||
for test_case in puzzle_data['test']:
|
||||
if 'output' in test_case:
|
||||
all_outputs.append(test_case['output'])
|
||||
|
||||
if not all_outputs:
|
||||
print(f" ⚠ {puzzle_id}: No test outputs found")
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Create new solution as array of all outputs
|
||||
new_solution = json.dumps(all_outputs)
|
||||
|
||||
# Check if it's different from current
|
||||
if current_solution == new_solution:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Update the solution
|
||||
cursor.execute("""
|
||||
UPDATE arc_jsons
|
||||
SET solution = %s
|
||||
WHERE id = %s
|
||||
""", (new_solution, row_id))
|
||||
|
||||
updated += 1
|
||||
|
||||
if updated % 20 == 0:
|
||||
print(f" Updated: {updated}/{len(entries)}")
|
||||
|
||||
except Exception as e:
|
||||
errors += 1
|
||||
print(f" ✗ Error with {puzzle_id}: {e}")
|
||||
if errors > 10:
|
||||
print("Too many errors, stopping...")
|
||||
break
|
||||
|
||||
# Commit changes
|
||||
connection.commit()
|
||||
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f"✓ Solution update complete!")
|
||||
print(f" Updated: {updated}")
|
||||
print(f" Skipped (unchanged): {skipped}")
|
||||
print(f" Errors: {errors}")
|
||||
|
||||
# Show a sample of updated solutions
|
||||
if updated > 0:
|
||||
cursor.execute("""
|
||||
SELECT aj.arc_puzzle_id, aj.solution
|
||||
FROM arc_jsons aj
|
||||
JOIN arc_puzzles ap ON aj.arc_puzzle_id = ap.id
|
||||
WHERE ap.corpora = 'ConceptArc'
|
||||
LIMIT 3
|
||||
""")
|
||||
|
||||
print(f"\n{'=' * 50}")
|
||||
print("Sample updated entries:")
|
||||
print("=" * 50)
|
||||
|
||||
for puzzle_id, solution in cursor.fetchall():
|
||||
if solution:
|
||||
sol_data = json.loads(solution)
|
||||
print(f"\n{puzzle_id}:")
|
||||
print(f" Number of test outputs: {len(sol_data)}")
|
||||
if isinstance(sol_data, list) and len(sol_data) > 0:
|
||||
first_output = sol_data[0]
|
||||
if isinstance(first_output, list):
|
||||
print(f" First output dimensions: {len(first_output)}x{len(first_output[0]) if first_output else 0}")
|
||||
|
||||
except Exception as e:
|
||||
connection.rollback()
|
||||
print(f"\n✗ Error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return 1
|
||||
finally:
|
||||
connection.close()
|
||||
print(f"\n✓ Database connection closed")
|
||||
|
||||
return 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
Reference in New Issue
Block a user