"""API Client for EKZ Tariff API."""
from __future__ import annotations

import asyncio
import logging
from datetime import datetime, timedelta, timezone
from typing import Any

import aiohttp

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# Correct API endpoint: the base URL and the tariffs path are concatenated
# to form the request URL.
API_BASE_URL = "https://api.tariffs.ekz.ch"
API_TARIFFS_ENDPOINT = "/v1/tariffs"


class EKZTariffAPIError(Exception):
    """Error type used to report failures of the EKZ tariff API client."""


class EKZTariffAPI:
    """API client for the EKZ tariff service.

    Fetches the tariff JSON with the ``aiohttp`` session supplied by the
    caller, converts it into a normalised structure (see
    :meth:`_parse_response`) and keeps the last successful result in an
    in-memory cache.
    """

    # Total time budget, in seconds, for a single HTTP request.
    _REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self, session: aiohttp.ClientSession) -> None:
        """Initialize the API client.

        Args:
            session: The ``aiohttp`` session used for all requests. The
                client never closes it.
        """
        self._session = session
        # Last successfully parsed response and the UTC time it was stored.
        self._cache: dict[str, Any] = {}
        self._cache_time: datetime | None = None

    async def get_tariffs(
        self,
        use_cache: bool = True,
        cache_hours: int = 1,
    ) -> dict[str, Any]:
        """Fetch tariffs from the EKZ API.

        Args:
            use_cache: When true, return the cached result if it is younger
                than ``cache_hours``.
            cache_hours: Maximum age of a cached result, in hours.

        Returns:
            Dict with 'publication_timestamp' and 'prices' list. A cache hit
            returns the cached dict object itself, not a copy.

        Raises:
            EKZTariffAPIError: On a non-200 status, a timeout, a connection
                error, or any other failure while fetching or parsing.
        """
        if use_cache and self._cache and self._cache_time:
            cache_age = datetime.now(timezone.utc) - self._cache_time
            if cache_age < timedelta(hours=cache_hours):
                _LOGGER.debug("Using cached tariff data (age: %s)", cache_age)
                return self._cache

        url = f"{API_BASE_URL}{API_TARIFFS_ENDPOINT}"

        try:
            _LOGGER.debug("Fetching tariffs from %s", url)

            async with self._session.get(
                url,
                timeout=aiohttp.ClientTimeout(
                    total=self._REQUEST_TIMEOUT_SECONDS
                ),
            ) as response:
                if response.status != 200:
                    text = await response.text()
                    raise EKZTariffAPIError(
                        f"API request failed with status {response.status}: {text}"
                    )
                data = await response.json()

            parsed = self._parse_response(data)
        except EKZTariffAPIError:
            # Already in the client's own error type; pass through untouched.
            raise
        except aiohttp.ClientError as err:
            raise EKZTariffAPIError(f"Connection error: {err}") from err
        except asyncio.TimeoutError as err:
            # str() of a timeout error is empty, so spell the cause out.
            raise EKZTariffAPIError(
                f"Request timed out after {self._REQUEST_TIMEOUT_SECONDS} seconds"
            ) from err
        except Exception as err:
            raise EKZTariffAPIError(f"Unexpected error: {err}") from err

        # Only a fully parsed response replaces the cache.
        self._cache = parsed
        self._cache_time = datetime.now(timezone.utc)

        _LOGGER.debug(
            "Successfully fetched %d price slots",
            len(parsed.get("prices", [])),
        )
        return parsed

    def _parse_response(self, data: dict[str, Any]) -> dict[str, Any]:
        """Parse the API response into the client's normalised structure.

        API Format:
        {
            "publication_timestamp": "2025-12-18T14:12:49+01:00",
            "prices": [
                {
                    "start_timestamp": "2026-02-04T00:00:00+01:00",
                    "end_timestamp": "2026-02-04T00:15:00+01:00",
                    "electricity": [
                        {"unit": "CHF_m", "value": 3.00000},
                        {"unit": "CHF_kWh", "value": 0.13300}
                    ],
                    "grid": [
                        {"unit": "CHF_m", "value": 0.00000},
                        {"unit": "CHF_kWh", "value": 0.10980}
                    ],
                    "integrated": [
                        {"unit": "CHF_m", "value": 3.00000},
                        {"unit": "CHF_kWh", "value": 0.24280}
                    ],
                    ...
                }
            ]
        }

        Args:
            data: The decoded JSON object returned by the API.

        Returns:
            Dict with the raw 'publication_timestamp' value and a 'prices'
            list sorted by start time. Each price entry holds 'start' and
            'end' as ``datetime`` objects plus the CHF/kWh values
            'electricity_price', 'grid_price', 'integrated_price' and
            'regional_fees' (each ``None`` when unavailable) and
            'total_price' (integrated price plus regional fees, with a
            missing part counted as 0).

        Raises:
            EKZTariffAPIError: If ``data`` is not a JSON object.
        """
        if not isinstance(data, dict):
            raise EKZTariffAPIError(
                "Unexpected response format: expected a JSON object, "
                f"got {type(data).__name__}"
            )

        prices = []

        # ``or []`` also covers an explicit ``"prices": null``.
        for item in data.get("prices") or []:
            try:
                start = datetime.fromisoformat(item["start_timestamp"])
                end = datetime.fromisoformat(item["end_timestamp"])

                # Extract the CHF/kWh price of every component.
                electricity_price = self._get_kwh_price(item.get("electricity"))
                grid_price = self._get_kwh_price(item.get("grid"))
                integrated_price = self._get_kwh_price(item.get("integrated"))
                regional_fees = self._get_kwh_price(item.get("regional_fees"))
            except (KeyError, ValueError, TypeError, AttributeError) as err:
                # One malformed slot must not discard the whole response.
                _LOGGER.warning("Failed to parse price item: %s - %s", item, err)
                continue

            prices.append({
                "start": start,
                "end": end,
                "electricity_price": electricity_price,
                "grid_price": grid_price,
                "integrated_price": integrated_price,
                "regional_fees": regional_fees,
                # Total including regional fees.
                "total_price": (integrated_price or 0) + (regional_fees or 0),
            })

        prices.sort(key=lambda slot: slot["start"])

        return {
            "publication_timestamp": data.get("publication_timestamp"),
            "prices": prices,
        }

    def _get_kwh_price(
        self, price_list: list[dict[str, Any]] | None
    ) -> float | None:
        """Extract the CHF/kWh price from a price list.

        Args:
            price_list: List of ``{"unit": ..., "value": ...}`` entries, or
                ``None`` when the component is absent.

        Returns:
            The value of the first entry whose unit is ``"CHF_kWh"`` as a
            float, or ``None`` when there is no such entry or it carries no
            value. A missing value is reported as ``None`` rather than 0 so
            that it is not mistaken for free energy.

        Raises:
            ValueError, TypeError: If the value cannot be converted to float.
        """
        for entry in price_list or []:
            if entry.get("unit") == "CHF_kWh":
                value = entry.get("value")
                return None if value is None else float(value)
        return None

    def get_current_price(
        self,
        prices: list[dict[str, Any]],
        now: datetime | None = None,
    ) -> dict[str, Any] | None:
        """Get the price slot that contains the current time.

        Args:
            prices: Price entries as produced by :meth:`_parse_response`.
            now: Reference instant; defaults to the current time in UTC. It
                must be comparable with the 'start'/'end' values, i.e.
                timezone-aware when those are.

        Returns:
            The first entry with ``start <= now < end``, or ``None`` when no
            slot matches.
        """
        if now is None:
            now = datetime.now(timezone.utc)

        return next(
            (slot for slot in prices if slot["start"] <= now < slot["end"]),
            None,
        )

    def get_prices_for_date(
        self,
        prices: list[dict[str, Any]],
        date: datetime,
    ) -> list[dict[str, Any]]:
        """Get all prices whose slot starts on a specific calendar date.

        Args:
            prices: Price entries as produced by :meth:`_parse_response`.
            date: The day to select, as a ``datetime`` (its date part is
                used) or as a plain ``date`` object.

        Returns:
            The entries whose 'start' falls on that date, in input order.
            The date of each 'start' is taken in that timestamp's own UTC
            offset; no timezone conversion is performed.
        """
        target_date = date.date() if isinstance(date, datetime) else date
        return [
            slot for slot in prices
            if slot["start"].date() == target_date
        ]

    def calculate_statistics(
        self,
        prices: list[dict[str, Any]],
        price_key: str = "integrated_price",
    ) -> dict[str, float | None]:
        """Calculate min, max, and average from prices.

        Args:
            prices: Price entries as produced by :meth:`_parse_response`.
            price_key: Name of the entry field to aggregate.

        Returns:
            Dict with 'min', 'max' and 'avg'. Entries whose value is missing
            or ``None`` are ignored; all three results are ``None`` when no
            usable value remains.
        """
        values = [
            slot[price_key]
            for slot in prices
            if slot.get(price_key) is not None
        ]

        if not values:
            return {"min": None, "max": None, "avg": None}

        return {
            "min": min(values),
            "max": max(values),
            "avg": sum(values) / len(values),
        }
