import redis
import json
import pandas as pd
import joblib
import os
import argparse
from datetime import datetime, timedelta
from sklearn.ensemble import IsolationForest

from flow_ai_module import analyze_and_tag as analyze_flow
from http_ai_module import analyze_and_tag as analyze_http
from dns_ai_module import analyze_and_tag as analyze_dns
from tls_ai_module import analyze_and_tag as analyze_tls

# Per-stream analyzer callables, keyed by stream name. extract_meta_features()
# calls each one as analyzer(events, model) and unpacks a 2-tuple whose first
# item is the collection of anomalies.
ANALYZERS = {
    "flow": analyze_flow,
    "http": analyze_http,
    "dns": analyze_dns,
    "tls": analyze_tls
}

# Stream names; each one is also the key of the Redis list read with LRANGE.
STREAMS = ["flow", "dns", "http", "tls"]

# Default file used by save_model() / load_model() for the meta model.
MODEL_PATH = "meta_model.pkl"
# Feature (column) names, one per line: written after training, read back at
# inference as a fallback schema for column alignment.
META_FEATURES_FILE = "meta_model_features.txt"
# CSV dump of the training feature matrix, written by main() when training.
FEATURE_LOG = "meta_features_log.csv"


# ---------------------- Helpers ----------------------

def _align_features_to_model(df: pd.DataFrame, model, fallback_cols=None) -> pd.DataFrame:
    """
    Return a new DataFrame whose columns match the schema the model expects.

    The schema is resolved in this order:
      1. ``model.feature_names_in_`` when the model exposes it,
      2. ``fallback_cols`` (e.g. a schema saved next to the model),
      3. otherwise the columns of ``df`` are kept as they are.

    Columns missing from ``df`` are added and filled with 0, columns not in
    the schema are dropped, the column order follows the schema, and every
    NaN is replaced by 0.

    The input DataFrame is never modified; callers keep their original
    columns.
    """
    expected = getattr(model, "feature_names_in_", None)
    if expected is not None:
        columns = list(expected)
    elif fallback_cols:
        columns = list(fallback_cols)
    else:
        # No schema available: only make the values NaN-free.
        return df.fillna(0)

    # reindex() builds a new frame, so the caller's df is left untouched.
    return df.reindex(columns=columns, fill_value=0).fillna(0)


# ---------------------- Meta feature extraction ----------------------

from datetime import datetime, timedelta, timezone

def fetch_events(redis_conn, event_type, minutes=60, shift_minutes=0):
    """
    Return the events of one Redis list whose timestamp lies in a time window.

    The window is ``minutes`` long and ends ``shift_minutes`` before the
    current UTC time. With the default ``shift_minutes=0`` the window is
    open-ended towards the present (every event newer than the start is
    kept). With a positive shift the window is closed at its end, so windows
    taken at different shifts slide back in time instead of growing.

    Args:
        redis_conn: Object exposing ``lrange(key, start, end)``.
        event_type: Key of the Redis list to read.
        minutes: Window length in minutes.
        shift_minutes: How far back (in minutes) the window end is moved.

    Returns:
        List of decoded event dicts, in the order they were read. Entries
        that are not valid JSON objects, have no ``timestamp`` or carry a
        timestamp ``datetime.fromisoformat`` cannot parse are skipped.
    """
    window_end = datetime.now(timezone.utc) - timedelta(minutes=shift_minutes)
    window_start = window_end - timedelta(minutes=minutes)
    # Only a shifted (historical) window gets an upper bound; the live window
    # keeps accepting everything newer than its start, as it always did.
    bounded = shift_minutes > 0

    selected = []
    for raw in redis_conn.lrange(event_type, 0, -1):
        try:
            event = json.loads(raw)
            ts = event.get("timestamp")
            if not ts:
                continue
            stamp = datetime.fromisoformat(ts)
        except (ValueError, TypeError, AttributeError):
            # Malformed JSON, non-dict payload or unparsable timestamp.
            continue

        # Naive timestamps are interpreted as UTC.
        if stamp.tzinfo is None:
            stamp = stamp.replace(tzinfo=timezone.utc)

        if stamp < window_start:
            continue
        if bounded and stamp > window_end:
            continue
        selected.append(event)

    return selected


def _count_stream_anomalies(stream, events):
    """Run the stream's analyzer on ``events`` and return how many anomalies it reports (0 on any failure)."""
    if not events:
        return 0

    model_path = f"{stream}_model.pkl"
    if not os.path.exists(model_path):
        print(f"[!] Model not found: {model_path}")
        return 0

    try:
        # NOTE: joblib.load unpickles the file; only load model files that
        # come from a trusted location.
        model = joblib.load(model_path)
        anomalies, _ = ANALYZERS[stream](events, model)
        return len(anomalies)
    except Exception as e:
        # Best effort: a broken model or analyzer must not abort the snapshot.
        print(f"[!] Analysis error {stream}: {e}")
        return 0


def _as_number(value):
    """Coerce a feature value to a plain number, falling back to 0."""
    if value is None:
        return 0
    if isinstance(value, bool):
        return int(value)
    if isinstance(value, (int, float)):
        return value
    try:
        return float(value)
    except Exception:
        return 0


def extract_meta_features(redis_conn, minutes=60, shift=0):
    """
    Build a single-row DataFrame with aggregated features for the meta model.

    Features, in column order:
      - count_<stream>_total and count_<stream>_anomalies for every stream
        in STREAMS (anomalies are counted by the stream's analyzer using the
        model stored in ``<stream>_model.pkl``; 0 when there are no events,
        no model file, or the analysis fails)
      - unique_src_ips: number of distinct non-empty ``src_ip`` values among
        the flow events

    Args:
        redis_conn: Redis connection handed to fetch_events().
        minutes: Window length in minutes.
        shift: Minutes the window is shifted back from now.
    """
    result = {}
    unique_src_ips = None

    for t in STREAMS:
        events = fetch_events(redis_conn, t, minutes, shift_minutes=shift)
        result[f"count_{t}_total"] = len(events)

        if t == "flow":
            # Reuse the flow events fetched here instead of reading and
            # parsing the whole list a second time. Counted before the
            # analyzer runs, in case the analyzer alters the events.
            unique_src_ips = len({e.get("src_ip") for e in events if e.get("src_ip")})

        result[f"count_{t}_anomalies"] = _count_stream_anomalies(t, events)

    # Additional meta signals
    if unique_src_ips is None:
        # "flow" is not part of STREAMS: fetch it explicitly.
        flow = fetch_events(redis_conn, "flow", minutes, shift_minutes=shift)
        unique_src_ips = len({e.get("src_ip") for e in flow if e.get("src_ip")})
    result["unique_src_ips"] = unique_src_ips

    # Ensure numeric stability (avoid types that may break sklearn)
    result = {key: _as_number(value) for key, value in result.items()}

    return pd.DataFrame([result])


def analyze_snapshot(redis_conn, model, shift=0, minutes=60):
    """
    Score one window of meta features with the meta model.

    Args:
        redis_conn: Redis connection handed to extract_meta_features().
        model: Fitted model exposing ``predict``; a prediction of -1 is
            treated as "anomaly".
        shift: Minutes the window is shifted back from now.
        minutes: Window length in minutes.

    Returns:
        ``(is_anomaly, features)`` where ``is_anomaly`` is a plain ``bool``
        and ``features`` is the aligned feature row as a dict. Returns
        ``(False, {})`` when no features could be built.
    """
    df = extract_meta_features(redis_conn, minutes=minutes, shift=shift)
    if df.empty:
        return False, {}

    # Saved schema, used for alignment when the model carries none itself.
    fallback_cols = None
    try:
        with open(META_FEATURES_FILE, encoding="utf-8") as f:
            fallback_cols = [line.strip() for line in f if line.strip()]
    except FileNotFoundError:
        pass  # no saved schema: alignment relies on the model alone

    df = _align_features_to_model(df, model, fallback_cols=fallback_cols)

    prediction = model.predict(df)[0]
    # bool() turns the numpy comparison result into a plain Python bool so
    # callers can serialize it or test it with `is True`.
    return bool(prediction == -1), df.to_dict(orient="records")[0]


# ---------------------- Model I/O ----------------------

def train_model(df):
    """Fit a new IsolationForest on ``df`` and return the fitted estimator."""
    forest_params = {
        "n_estimators": 100,
        "contamination": 0.05,
        "random_state": 42,
    }
    # fit() returns the estimator itself, so it can be returned directly.
    return IsolationForest(**forest_params).fit(df)


def save_model(model, path=MODEL_PATH):
    """Write ``model`` to ``path`` with joblib and print where it was stored."""
    joblib.dump(model, path)
    confirmation = f"[+] Meta model saved in {path}"
    print(confirmation)


def load_model(path=MODEL_PATH):
    """
    Load and return the model stored at ``path`` with joblib.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if os.path.exists(path):
        return joblib.load(path)
    raise FileNotFoundError(f"Model not found: {path}")


# ---------------------- Multi-anomaly detection ----------------------

from datetime import datetime, timedelta, timezone

def detect_multi_anomalies(redis_conn, interval='minute', minutes=60):
    """
    Find time buckets in which at least two streams reported an anomaly.

    Every list named in STREAMS is read in full; events flagged with a truthy
    ``anomaly`` field and a timestamp within the last ``minutes`` minutes are
    grouped into buckets of size ``interval``.

    Args:
        redis_conn: Object exposing ``lrange(key, start, end)``.
        interval: Bucket size: ``"minute"``, ``"hour"`` or ``"day"``. Any
            other value yields an empty result without reading Redis.
        minutes: How far back from now (UTC) events are considered.

    Returns:
        List of dicts sorted by bucket time, each with ``time`` (bucket start
        formatted as ``%Y-%m-%d %H:%M``), ``types`` (sorted stream names),
        ``count`` (number of streams) and ``level`` (``count - 1``, capped
        at 3).
    """
    # Datetime fields reset to zero to reach the start of each bucket size.
    truncations = {
        "minute": {"second": 0, "microsecond": 0},
        "hour": {"minute": 0, "second": 0, "microsecond": 0},
        "day": {"hour": 0, "minute": 0, "second": 0, "microsecond": 0},
    }
    truncate = truncations.get(interval)
    if truncate is None:
        # Nothing could ever be bucketed: skip scanning Redis altogether.
        print(f"[!] Unsupported interval: {interval}")
        return []

    # tz-aware "now" in UTC
    start = datetime.now(timezone.utc) - timedelta(minutes=minutes)
    buckets = {}

    for stream in STREAMS:
        for line in redis_conn.lrange(stream, 0, -1):
            try:
                e = json.loads(line)
                ts = e.get("timestamp")
                if not ts or not e.get("anomaly"):
                    continue
                dt = datetime.fromisoformat(ts)
            except (ValueError, TypeError, AttributeError):
                # Malformed JSON, non-dict payload or unparsable timestamp.
                continue

            # Naive timestamps are interpreted as UTC.
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=timezone.utc)

            if dt < start:
                continue

            buckets.setdefault(dt.replace(**truncate), set()).add(stream)

    result = []
    for bucket_time, streams_set in sorted(buckets.items(), key=lambda item: item[0]):
        count = len(streams_set)
        if count < 2:
            continue  # multi-anomaly = at least 2 streams

        result.append({
            "time": bucket_time.strftime("%Y-%m-%d %H:%M"),
            "types": sorted(streams_set),
            "count": count,
            # 2 -> 1, 3 -> 2, 4 and more -> 3: the level never drops when
            # more streams are involved.
            "level": min(count - 1, 3),
        })

    return result


def log_multi_anomaly_to_redis(redis_conn, anomalies, redis_key="multianomalies", max_check=5000):
    """
    Push multi-anomaly entries that are not stored yet onto a Redis list.

    An entry is identified by its ``time`` together with its sorted
    ``types``. Entries already present among the rows read from Redis, or
    repeated inside ``anomalies`` itself, are skipped. Every stored entry is
    a copy of the input with ``escalated`` set to False; the input dicts are
    left untouched.

    Args:
        redis_conn: Object exposing ``lrange`` and ``lpush``.
        anomalies: Iterable of dicts with at least ``time`` and ``types``.
        redis_key: Key of the target Redis list.
        max_check: End index passed to LRANGE for the duplicate check. The
            end index is inclusive, so up to ``max_check + 1`` rows from the
            head of the list are examined.

    Returns:
        None. Progress and errors are reported with print(); a failed read
        aborts the call, a failed write skips only that entry.
    """
    try:
        existing = redis_conn.lrange(redis_key, 0, max_check)
    except Exception as e:
        print(f"[!] Error reading Redis: {e}")
        return

    seen_keys = set()
    for row in existing:
        try:
            obj = json.loads(row)
            seen_keys.add((obj.get("time"), tuple(sorted(obj.get("types", [])))))
        except (ValueError, TypeError, AttributeError):
            # Unreadable row: it cannot match anything, ignore it.
            continue

    added = 0
    for entry in anomalies:
        key = (entry["time"], tuple(sorted(entry["types"])))
        if key in seen_keys:
            continue

        entry_to_store = entry.copy()
        entry_to_store["escalated"] = False
        try:
            redis_conn.lpush(redis_key, json.dumps(entry_to_store))
        except Exception as e:
            print(f"[!] Error writing to Redis: {e}")
            continue

        # Remember what was just stored so a repeat later in this same batch
        # is not pushed a second time.
        seen_keys.add(key)
        added += 1

    print(f"[+] New multi-anomalies added: {added}")


# ---------------------- CLI ----------------------

def main(args=None):
    """
    Command-line entry point: train the meta model or score one snapshot.

    Args:
        args: Pre-parsed options exposing ``train``, ``check``, ``windows``,
            ``window_size`` and ``shift``. When None, the options are parsed
            from the command line.

    With ``train`` set, one feature row is extracted per window (window ``i``
    is shifted ``i`` minutes back), the matrix is written to FEATURE_LOG, the
    model is trained and saved, and the column names are written to
    META_FEATURES_FILE. With ``check`` set, the saved model scores a single
    window and the result is printed. ``train`` takes precedence when both
    are set.
    """
    if args is None:
        parser = argparse.ArgumentParser(description="Meta analysis for Suricata")
        parser.add_argument('--train', action='store_true',
                            help="train the meta model and save it")
        parser.add_argument('--check', action='store_true',
                            help="score one window with the saved meta model")
        parser.add_argument('--windows', type=int, default=200,
                            help="number of training windows, each shifted one more minute back")
        parser.add_argument('--window_size', type=int, default=60,
                            help="window length in minutes")
        parser.add_argument('--shift', type=int, default=0,
                            help="minutes to shift the checked window back from now")
        args = parser.parse_args()

    r = redis.Redis(host="localhost", port=6379, decode_responses=True)

    if args.train:
        records = []
        for i in range(args.windows):
            df = extract_meta_features(r, minutes=args.window_size, shift=i)
            if not df.empty:
                records.append(df.iloc[0])

        if not records:
            # e.g. --windows 0: fitting on an empty frame would only fail
            # later with an unrelated-looking sklearn error.
            print("[!] No feature windows collected; meta model was not trained")
            return

        # Ensure numeric stability and no NaNs before training
        df_all = pd.DataFrame(records).fillna(0)
        for c in df_all.columns:
            if not pd.api.types.is_numeric_dtype(df_all[c]):
                df_all[c] = pd.to_numeric(df_all[c], errors="coerce").fillna(0)

        df_all.to_csv(FEATURE_LOG, index=False)
        model = train_model(df_all)
        save_model(model)

        # Persist the meta feature schema for inference alignment
        with open(META_FEATURES_FILE, "w", encoding="utf-8") as f:
            f.write("\n".join(df_all.columns))

    elif args.check:
        model = load_model()
        is_anomaly, features = analyze_snapshot(r, model, shift=args.shift, minutes=args.window_size)
        print(f"🔥 Meta-anomaly (shift {args.shift} min): {is_anomaly}")
        print("📊 Snapshot of behavior:")
        print(json.dumps(features, indent=2, ensure_ascii=False))

    else:
        print("[!] Nothing to do: pass --train or --check")

if __name__ == '__main__':
    main()
