#!/usr/bin/env python3
"""
Script pour ajouter ou mettre à jour les villes de Côte d'Ivoire
dans la base de données TIMCASH.
"""

import os
from pathlib import Path
import sys
from dotenv import load_dotenv

from models.models import City

# Charger les variables d'environnement
load_dotenv()

# Ajouter le dossier backend au path pour importer tes modules
sys.path.append(str(Path(__file__).resolve().parents[1]))

# Importer SQLAlchemy
from sqlalchemy import create_engine, Column, String, Integer, ForeignKey, func
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.exc import IntegrityError

# Import settings si tu as une config centralisée
try:
    from core.config import settings
    DATABASE_URL = settings.DATABASE_URL
except ImportError:
    DATABASE_URL = os.getenv(
        "DATABASE_URL",
        "postgresql+psycopg2://timcash:timcash123@localhost:5432/timcash_db"
    )

# === Définition de la base et session SQLAlchemy ===
Base = declarative_base()
engine = create_engine(DATABASE_URL, echo=False)
SessionLocal = sessionmaker(bind=engine)


# === Modèle City minimal pour ce script ===
# Créer les tables si elles n’existent pas
Base.metadata.create_all(bind=engine)


# === Liste des villes de Côte d'Ivoire ===
ci_cities = [
    {"name": "Abidjan", "branch_code": "Abidjan"},
    {"name": "Yamoussoukro", "branch_code": "Yamoussoukro"},
    {"name": "Bouaké", "branch_code": "Bouaké"},
    # ... ajoute ici toutes les autres villes
]

# === Script principal ===
def main():
    session = SessionLocal()
    added = 0
    updated = 0

    try:
        for city_data in ci_cities:
            city = session.query(City).filter_by(name=city_data["name"]).first()
            if city:
                updated += 1
            else:
                # Ajout
                city = City(**city_data)
                session.add(city)
                added += 1
        session.commit()

        print(f"\n📊 Résumé:")
        print(f"   ✅ Ajoutées: {added}")
        print(f"   🔄 Mises à jour: {updated}")
        print(f"   📋 Total traité: {len(ci_cities)}")

        # Résumé par région
        print("\n📍 Répartition par région :")
        regions = session.query(
            City.name, func.count(City.id).label("total_cities")
        ).group_by(City.name).order_by(City).all()

        for r in regions:
            print(f"   🌍 Région: {r} - Total villes: {r.total_cities}")

    except IntegrityError as e:
        print("❌ Erreur lors de l'ajout:", e)
        session.rollback()
    finally:
        session.close()


if __name__ == "__main__":
    main()
