#!/bin/bash
#
# Migrate legacy POLYMATICA_CORE_OIDC_* settings (environment of the "manager"
# service in docker-compose.yml, or the "oidc" section of a JSON document) into
# a standalone oidc-providers.json, then point the compose file at it through
# POLYMATICA_CORE_OIDC_PROVIDERS_CONFIG_PATH. Requires python3; see --help.
set -euo pipefail

# Defaults; overridable with -i/--input and -o/--output.
INPUT_FILE="/srv/platform/docker-compose.yml"
OUTPUT_FILE="/srv/platform/oidc-providers.json"
# Value written into the compose file for POLYMATICA_CORE_OIDC_PROVIDERS_CONFIG_PATH
# (presumably the path at which the manager container reads the file — confirm).
readonly PROVIDERS_CONFIG_PATH="/etc/polymatica/platform/manager/oidc-providers.json"

#######################################
# Print the command-line help.
# Globals:   INPUT_FILE, OUTPUT_FILE (read, shown as defaults)
# Outputs:   help text on stdout
#######################################
show_usage() {
  local prog
  prog="$(basename "$0")"
  printf '%s\n' \
    "Usage: ${prog} [OPTIONS]" \
    "" \
    "Migrate Docker OIDC settings into oidc-providers.json and update docker-compose.yml." \
    "The migration runs only once: if oidc-providers.json already exists, the script exits." \
    "" \
    "Options:" \
    "  -i, --input <path>    Source docker-compose.yml (default: ${INPUT_FILE})" \
    "  -o, --output <path>   Destination oidc-providers.json (default: ${OUTPUT_FILE})" \
    "  -h, --help            Show this help"
}

#######################################
# Store a non-empty path option value in the named global, or fail with usage.
# Arguments: $1 - name of the global to set, $2 - option as typed, $3 - value
#######################################
_set_path_option() {
  local var_name="$1" option="$2" value="$3"
  if [[ -z "${value}" ]]; then
    echo "Option ${option} requires a non-empty <path> argument" >&2
    show_usage >&2
    exit 1
  fi
  printf -v "${var_name}" '%s' "${value}"
}

#######################################
# Parse command-line options.
# Globals:   INPUT_FILE, OUTPUT_FILE (written)
# Arguments: the script's arguments
# Outputs:   help on stdout for -h; errors and usage on stderr
# Returns:   exits 0 after --help; exits 1 on an unknown option or a
#            missing/empty option value
#######################################
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "$1" in
      -i|--input)
        # "${2:-}" rather than "$2": under `set -u` a missing value would
        # otherwise abort with a cryptic "unbound variable" error.
        _set_path_option INPUT_FILE "$1" "${2:-}"
        shift 2
        ;;
      --input=*)
        _set_path_option INPUT_FILE --input "${1#*=}"
        shift
        ;;
      -o|--output)
        _set_path_option OUTPUT_FILE "$1" "${2:-}"
        shift 2
        ;;
      --output=*)
        _set_path_option OUTPUT_FILE --output "${1#*=}"
        shift
        ;;
      -h|--help)
        show_usage
        exit 0
        ;;
      *)
        echo "Unknown option: $1" >&2
        show_usage >&2
        exit 1
        ;;
    esac
  done
}

parse_args "$@"

# Pre-flight checks, in order: the source must exist; an existing output file
# means the migration already ran (exit successfully without touching
# anything); python3 must be installed. Then create the output directory.
[[ -f "${INPUT_FILE}" ]] || {
  printf '%s\n' "Input file not found: ${INPUT_FILE}" >&2
  exit 1
}

if [[ -f "${OUTPUT_FILE}" ]]; then
  printf '%s\n' "OIDC settings have already been migrated: ${OUTPUT_FILE}"
  exit 0
fi

command -v python3 >/dev/null 2>&1 || {
  printf '%s\n' "python3 is required to parse the source config" >&2
  exit 1
}

mkdir -p "$(dirname "${OUTPUT_FILE}")"

# Run the migration logic in Python, read from stdin ("python3 -").
# The quoted 'PY' delimiter stops the shell from expanding $... inside the
# script; the shell values are passed as argv[1..3] instead.
python3 - "${INPUT_FILE}" "${OUTPUT_FILE}" "${PROVIDERS_CONFIG_PATH}" <<'PY'
import json
import re
import sys
from pathlib import Path

# argv[1]: source file (docker-compose.yml, or a JSON document with an "oidc" section).
input_path = Path(sys.argv[1])
# argv[2]: destination oidc-providers.json.
output_path = Path(sys.argv[2])
# argv[3]: value written to POLYMATICA_CORE_OIDC_PROVIDERS_CONFIG_PATH in the compose file.
providers_config_path = sys.argv[3]
text = input_path.read_text(encoding="utf-8")


def normalize_scalar(value):
    """Normalize one configuration value for the JSON output.

    ``None`` becomes ``""``; any other non-string (bool, number, dict, list)
    is returned unchanged. Strings are stripped, one pair of matching
    surrounding quotes (``'...'`` or ``"..."``) is removed, and the words
    ``true``/``false`` in any letter case become booleans.
    """
    if value is None:
        return ""
    if not isinstance(value, str):
        return value

    cleaned = value.strip()
    quoted = len(cleaned) >= 2 and cleaned[0] in ("'", '"') and cleaned[-1] == cleaned[0]
    if quoted:
        cleaned = cleaned[1:-1]
    return {"true": True, "false": False}.get(cleaned.lower(), cleaned)


def _strip_inline_comment(raw_value):
    """Drop a trailing comment (``<whitespace>#...``) from an unquoted value.

    In a plain YAML scalar (and an unquoted compose .env value) a ``#``
    preceded by whitespace starts a comment, so ``- KEY=abc  # note`` carries
    the value ``abc``. Values starting with a quote are returned unchanged
    apart from surrounding whitespace.
    """
    value = raw_value.strip()
    if value[:1] in ("'", '"'):
        return value
    return re.split(r"\s+#", value, maxsplit=1)[0]


def build_provider_from_env(raw_text):
    """Build one provider dict from legacy ``POLYMATICA_CORE_OIDC_*`` variables.

    Every line of ``raw_text`` shaped like ``- POLYMATICA_CORE_OIDC_<NAME>=<value>``
    (compose list item) or ``POLYMATICA_CORE_OIDC_<NAME>=<value>`` contributes
    the key ``<name>`` (lower-cased). Claim settings (``PROVIDER_CLAIMS_*`` or
    one of the bare ``*_KEY`` names) are grouped under ``provider_claims``;
    ``LOGO``/``PROVIDER_LOGO``/``PROVIDER_ICON`` all map to ``provider_icon``;
    ``PROVIDERS_CONFIG_PATH`` is skipped. If a key repeats, the last one wins.
    Note that the whole text is scanned, not only the "manager" service.

    Raises:
        SystemExit: if no provider variable (besides PROVIDERS_CONFIG_PATH) is found.
    """
    claim_keys = {
        "preferred_username_key",
        "email_key",
        "first_name_key",
        "middle_name_key",
        "last_name_key",
        "group_key",
        "role_key",
    }
    icon_keys = {"provider_logo", "logo", "provider_icon"}
    env_line = re.compile(r"^\s*(?:-\s*)?POLYMATICA_CORE_OIDC_([A-Z0-9_]+)=(.*)$")

    provider = {}
    claims = {}
    for line in raw_text.splitlines():
        match = env_line.match(line)
        if not match:
            continue

        suffix = match.group(1).lower()
        if suffix == "providers_config_path":
            # This is the pointer the migration writes, not a provider setting.
            continue

        value = normalize_scalar(_strip_inline_comment(match.group(2)))
        if suffix.startswith("provider_claims_"):
            claims[suffix[len("provider_claims_"):]] = value
        elif suffix in claim_keys:
            claims[suffix] = value
        elif suffix in icon_keys:
            provider["provider_icon"] = value
        else:
            provider[suffix] = value

    if not provider:
        raise SystemExit("No legacy POLYMATICA_CORE_OIDC_* variables found in source file")

    if claims:
        provider["provider_claims"] = claims

    return finalize_provider(provider)


def find_oidc_section(value):
    """Return the value of the first ``"oidc"`` key found in nested data.

    The search is depth-first, in document order, through dicts and lists.
    A dict that has an ``"oidc"`` key is not searched further. If that key's
    value is ``None``, the search continues with the remaining siblings.
    Returns ``None`` when nothing is found.
    """
    pending = [value]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            if "oidc" in node:
                section = node["oidc"]
                if section is not None:
                    return section
                continue
            children = list(node.values())
        elif isinstance(node, list):
            children = node
        else:
            continue
        # Reverse so the first child is popped (visited) first.
        pending.extend(reversed(children))
    return None


def finalize_provider(provider):
    """Fill in output defaults and derive ``provider_type``; mutates and returns ``provider``.

    Missing keys get: ``insecure_override_issuer=False``,
    ``provider_actual_issuer=""``, ``provider_claims={}`` and
    ``provider_icon=""``, added in that order. When ``provider_name`` is
    truthy, it is copied to ``provider_type``.
    """
    defaults = (
        ("insecure_override_issuer", False),
        ("provider_actual_issuer", ""),
        ("provider_claims", {}),
        ("provider_icon", ""),
    )
    for key, default in defaults:
        if key not in provider:
            provider[key] = default

    name = provider.get("provider_name")
    if name:
        provider["provider_type"] = name

    return provider


def normalize_provider_dict(raw_provider, default_name=None, enabled=None):
    """Convert one provider object from a JSON "oidc" section to the output schema.

    Keys are matched case-insensitively. Known aliases (``issuer``,
    ``client_id``, ``logo``, ``enabled``, ...) are renamed to their output
    names. Claim settings go under ``provider_claims``: a nested
    ``claims``/``provider_claims`` object, ``claims_*`` or ``provider_claims_*``
    keys, or one of the bare ``*_key`` names. Unknown keys with scalar values
    are kept (lower-cased); unknown keys holding objects or lists are dropped.

    Args:
        raw_provider: the provider mapping from the source document.
        default_name: used as ``provider_name`` when the mapping has none.
        enabled: used as ``auth_method_active`` when the mapping has none.

    Returns:
        The provider dict after ``finalize_provider``.
    """
    claim_keys = frozenset((
        "preferred_username_key",
        "email_key",
        "first_name_key",
        "middle_name_key",
        "last_name_key",
        "group_key",
        "role_key",
    ))
    # For each of these, both "<x>" and "provider_<x>" map to "provider_<x>".
    prefixed_fields = (
        "issuer",
        "actual_issuer",
        "redirect_url",
        "logout_url",
        "client_id",
        "client_secret",
        "scope",
        "name",
        "authorize_iam_url",
        "use_token_auth",
        "skip_token_check",
        "auto_create_user",
        "groups_whitelist",
    )
    key_map = {
        alias: "provider_" + field
        for field in prefixed_fields
        for alias in (field, "provider_" + field)
    }
    key_map.update({
        "auth_method_active": "auth_method_active",
        "enabled": "auth_method_active",
        "insecure_override_issuer": "insecure_override_issuer",
        "logo": "provider_icon",
        "provider_logo": "provider_icon",
        "icon": "provider_icon",
        "provider_icon": "provider_icon",
    })

    def claim_name(key):
        # Name under provider_claims for a claim-style key, otherwise None.
        for prefix in ("claims_", "provider_claims_"):
            if key.startswith(prefix):
                return key[len(prefix):]
        return key if key in claim_keys else None

    provider = {}
    claims = {}
    for raw_key, raw_value in raw_provider.items():
        key = str(raw_key).lower()
        claim = claim_name(key)
        if key in ("provider_claims", "claims") and isinstance(raw_value, dict):
            for nested_key, nested_value in raw_value.items():
                claims[str(nested_key).lower()] = normalize_scalar(nested_value)
        elif claim is not None:
            claims[claim] = normalize_scalar(raw_value)
        elif key in key_map:
            provider[key_map[key]] = normalize_scalar(raw_value)
        elif not isinstance(raw_value, (dict, list)):
            provider[key] = normalize_scalar(raw_value)

    if enabled is not None and "auth_method_active" not in provider:
        provider["auth_method_active"] = normalize_scalar(enabled)

    if default_name and "provider_name" not in provider:
        provider["provider_name"] = default_name

    if claims:
        provider["provider_claims"] = claims

    return finalize_provider(provider)


def build_providers_from_oidc(oidc_section):
    """Convert the "oidc" section of a JSON source into a list of providers.

    Accepted shapes, tried in this order:
      * a list of provider objects;
      * ``{"providers": [{...}, ...]}`` or ``{"providers": {"<name>": {...}}}``;
      * ``{"provider": {...}}``;
      * a single flat provider object, recognized by keys such as ``issuer``,
        ``client_id`` or ``provider_name``;
      * ``{"<name>": {...}, ...}``, one provider per nested object.
    For the nested shapes, the section's ``enabled`` value is used as the
    default ``auth_method_active``.

    Raises:
        SystemExit: if the section is neither a list nor an object, is a list
            without provider objects, or matches none of the shapes above.
    """
    if isinstance(oidc_section, list):
        providers = [normalize_provider_dict(item) for item in oidc_section if isinstance(item, dict)]
        if not providers:
            # Previously this fell through to the "must be an object or list"
            # error, which is misleading for a list.
            raise SystemExit('The "oidc" list does not contain any provider objects')
        return providers

    if not isinstance(oidc_section, dict):
        raise SystemExit('The "oidc" section must be an object or list')

    enabled = oidc_section.get("enabled")

    def from_named_mapping(mapping):
        # {"<name>": {...provider fields...}, ...}; non-object values are ignored.
        return [
            normalize_provider_dict(value, default_name=str(name), enabled=enabled)
            for name, value in mapping.items()
            if isinstance(value, dict)
        ]

    providers_section = oidc_section.get("providers")
    if isinstance(providers_section, list):
        providers = [
            normalize_provider_dict(item, enabled=enabled)
            for item in providers_section
            if isinstance(item, dict)
        ]
        if providers:
            return providers
    elif isinstance(providers_section, dict):
        providers = from_named_mapping(providers_section)
        if providers:
            return providers

    single_provider = oidc_section.get("provider")
    if isinstance(single_provider, dict):
        return [normalize_provider_dict(single_provider, enabled=enabled)]

    known_final_keys = {
        "auth_method_active",
        "provider_name",
        "provider_issuer",
        "provider_client_id",
        "provider_client_secret",
        "issuer",
        "client_id",
        "client_secret",
    }
    if any(key in oidc_section for key in known_final_keys):
        return [normalize_provider_dict(oidc_section)]

    providers = from_named_mapping(oidc_section)
    if providers:
        return providers

    raise SystemExit('Could not normalize the "oidc" section into oidc-providers.json format')


def update_compose(raw_text, config_path):
    """Point the compose "manager" service at the new providers file.

    In ``services.manager.environment`` (expected to be a YAML list), every
    ``- POLYMATICA_CORE_OIDC_*=...`` item is removed. The result keeps one
    ``- POLYMATICA_CORE_OIDC_PROVIDERS_CONFIG_PATH=<config_path>`` item: an
    existing item is rewritten in place, otherwise a new one is added after the
    last list entry. All other lines are preserved byte-for-byte.

    The file is scanned line by line rather than parsed. It assumes the usual
    compose layout: ``manager:`` indented by exactly two spaces and
    ``environment:`` by exactly four.

    Args:
        raw_text: full text of docker-compose.yml.
        config_path: value for POLYMATICA_CORE_OIDC_PROVIDERS_CONFIG_PATH.

    Returns:
        The rewritten text, using the source's line-ending style.

    Raises:
        SystemExit: if the manager service or its environment section is missing.
    """
    lines = raw_text.splitlines(keepends=True)
    newline = "\r\n" if "\r\n" in raw_text else "\n"
    config_item = f"- POLYMATICA_CORE_OIDC_PROVIDERS_CONFIG_PATH={config_path}"

    def leading_ws(line):
        return line[: len(line) - len(line.lstrip(" \t"))]

    def is_content(line):
        # Blank lines and comments neither open nor close a YAML block.
        stripped = line.strip()
        return bool(stripped) and not stripped.startswith("#")

    def terminated(line):
        # The last line of a file may lack a newline; add one before appending after it.
        return line if line.endswith(("\n", "\r")) else line + newline

    def block_end(start, stop, parent_indent, allow_sequence=False):
        # A block ends at the first content line indented no deeper than its
        # parent key. This also catches top-level keys such as "networks:",
        # which previously ran into the block. With allow_sequence, "- item"
        # lines at the parent's own indent still belong to it (YAML permits
        # unindented sequences under a key).
        for index in range(start, stop):
            line = lines[index]
            if not is_content(line):
                continue
            indent = len(leading_ws(line))
            if indent < parent_indent:
                return index
            if indent == parent_indent and not (allow_sequence and line.lstrip().startswith("-")):
                return index
        return stop

    manager_index = next(
        (i for i, line in enumerate(lines) if re.match(r"^\s{2}manager:\s*$", line.rstrip("\r\n"))),
        None,
    )
    if manager_index is None:
        raise SystemExit('Could not find "manager" service in docker-compose.yml')
    service_end = block_end(manager_index + 1, len(lines), 2)

    env_index = next(
        (
            i
            for i in range(manager_index + 1, service_end)
            if re.match(r"^\s{4}environment:\s*$", lines[i].rstrip("\r\n"))
        ),
        None,
    )
    if env_index is None:
        raise SystemExit('Could not find "environment" section for "manager" service')
    env_end = block_end(env_index + 1, service_end, 4, allow_sequence=True)

    env_block = lines[env_index + 1:env_end]
    # Reuse the indentation of existing list items; six spaces is the
    # conventional layout when the list is empty.
    item_indent = next(
        (leading_ws(line) for line in env_block if line.lstrip().startswith("-")),
        " " * 6,
    )

    kept = []
    has_config_path = False
    for line in env_block:
        match = re.match(r"^-\s*(POLYMATICA_CORE_OIDC_[A-Z0-9_]+)=(.*)$", line.strip())
        if not match:
            kept.append(line)
        elif match.group(1) == "POLYMATICA_CORE_OIDC_PROVIDERS_CONFIG_PATH":
            kept.append(f"{leading_ws(line)}{config_item}{newline}")
            has_config_path = True
        # Every other legacy OIDC variable now lives in oidc-providers.json: drop it.

    head = lines[:env_index + 1]
    if not has_config_path:
        # Insert right after the last real entry, so the item lands inside the
        # list and not after trailing blank lines or comments.
        insert_at = max((pos + 1 for pos, line in enumerate(kept) if is_content(line)), default=0)
        if insert_at:
            kept[insert_at - 1] = terminated(kept[insert_at - 1])
        else:
            head[-1] = terminated(head[-1])
        kept.insert(insert_at, f"{item_indent}{config_item}{newline}")

    return "".join(head + kept + lines[env_end:])


# Entry point. For the legacy env-var layout, build one provider from the
# POLYMATICA_CORE_OIDC_* variables and rewrite the compose file to reference
# the new providers file. Otherwise treat the source as JSON and convert its
# "oidc" section; in that case the source file is not modified.
updated_compose = None
if "POLYMATICA_CORE_OIDC_" in text:
    providers = [build_provider_from_env(text)]
    # Render the new compose text before writing anything, so a layout error
    # (e.g. no "manager" service) aborts with both files untouched.
    updated_compose = update_compose(text, providers_config_path)
else:
    try:
        data = json.loads(text)
    except json.JSONDecodeError as exc:
        raise SystemExit(f"Unsupported source format: {exc}") from exc

    oidc_section = find_oidc_section(data)
    if oidc_section is None:
        raise SystemExit('Could not find an "oidc" section in the source file')

    providers = build_providers_from_oidc(oidc_section)

# Write the providers file first, through a temporary file and a rename.
# The shell wrapper treats its existence as "already migrated", so it must
# never be left half-written. Only afterwards strip the legacy variables from
# the compose file. With the opposite order, a failed write here would lose
# the OIDC settings for good.
tmp_output_path = output_path.with_name(output_path.name + ".tmp")
tmp_output_path.write_text(json.dumps(providers, indent=4, ensure_ascii=False) + "\n", encoding="utf-8")
tmp_output_path.replace(output_path)

if updated_compose is not None:
    input_path.write_text(updated_compose, encoding="utf-8")
PY

# Only reached when the Python step exits with status 0; `set -e` aborts the
# script on any failure above.
printf '%s\n' "OIDC settings migrated to ${OUTPUT_FILE}"
