# license_core_new.py
"""
New license validation system using asymmetric cryptography.
This module provides secure license validation using RSA public key signatures.
"""

import base64
import hashlib
import json
import os
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

# Cryptography imports
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.backends import default_backend

# Resource-path resolution and module-level constants
def get_license_resource_path(filename: str) -> str:
    """
    Resolve the on-disk location of a licensing resource file.

    In a frozen build (``sys.frozen`` truthy), three locations are probed in
    order and the first that exists is returned:

    1. the directory of ``sys.executable``,
    2. the directory of this module,
    3. the bare, CWD-relative ``filename``.

    Otherwise, or when none of them exist, the path next to this module is
    returned.

    Args:
        filename: Name of the resource file, e.g. ``"public_key.pem"``.

    Returns:
        str: The resolved path, or ``filename`` unchanged if path resolution
        itself raises.
    """
    try:
        module_dir = os.path.dirname(os.path.abspath(__file__))
        beside_module = os.path.join(module_dir, filename)
        if not getattr(sys, 'frozen', False):
            # Development environment
            return beside_module
        # Frozen build: the resource may live beside the executable, beside
        # this module, or in the current working directory.
        probes = (
            os.path.join(os.path.dirname(sys.executable), filename),
            beside_module,
            filename,
        )
        return next((probe for probe in probes if os.path.exists(probe)), beside_module)
    except Exception:
        # Final fallback
        return filename

# File extension used for signed license files.
LICENSE_FILE_EXTENSION: str = ".lic"
# Path of the PEM-encoded public key used to verify license signatures,
# resolved once at import time via get_license_resource_path().
PUBLIC_KEY_PATH: str = get_license_resource_path("public_key.pem")

def load_public_key(key_path: Optional[str] = None) -> Any:
    """
    Load the public key used to verify license signatures from a PEM file.

    Args:
        key_path: Path to a PEM-encoded public key. When omitted, the
            module-level ``PUBLIC_KEY_PATH`` is used, read at call time.

    Returns:
        Public key object for signature verification.

    Raises:
        ValueError: If the key file does not exist, cannot be read, or does
            not contain a valid PEM public key. Every failure, including a
            missing file, is surfaced as ValueError. The underlying exception
            (e.g. ``FileNotFoundError``) is chained as ``__cause__``.
    """
    raw_path = PUBLIC_KEY_PATH if key_path is None else key_path
    try:
        public_key_path_obj = Path(raw_path)
        if not public_key_path_obj.exists():
            raise FileNotFoundError(f"Public key not found at {raw_path}")

        return serialization.load_pem_public_key(
            public_key_path_obj.read_bytes(),
            backend=default_backend()
        )

    except Exception as e:
        # Normalize all failures to ValueError, but keep the original error
        # attached for debugging instead of discarding the traceback.
        raise ValueError(f"Failed to load public key: {str(e)}") from e

def validate_license_signature(license_data: Dict[str, Any], signature: str, public_key: Any) -> bool:
    """
    Verify an RSA-PSS / SHA-256 signature over the canonical JSON of a license.

    The license dict is serialized with sorted keys and compact separators
    (``","`` and ``":"``), UTF-8 encoded. This must match exactly the bytes
    that were signed.

    Args:
        license_data: The ``"license"`` section of a license file.
        signature: URL-safe Base64 encoding of the signature bytes.
        public_key: Loaded public key exposing ``verify(...)``.

    Returns:
        bool: True if verification succeeds. False on any error, such as a
        bad Base64 string or a signature mismatch; the error is printed to
        stdout.
    """
    try:
        payload = json.dumps(license_data, sort_keys=True, separators=(',', ':')).encode('utf-8')
        raw_signature = base64.urlsafe_b64decode(signature.encode('utf-8'))
        pss_padding = padding.PSS(
            mgf=padding.MGF1(hashes.SHA256()),
            salt_length=padding.PSS.MAX_LENGTH
        )
        # verify() returns None on success and raises on any mismatch.
        public_key.verify(raw_signature, payload, pss_padding, hashes.SHA256())
    except Exception as e:
        print(f"Signature verification failed: {str(e)}")
        return False
    return True

def validate_license_file_format(license_file_data: Dict[str, Any]) -> Tuple[bool, str]:
    """
    Validate the structure and format of a parsed license file.

    Checks are applied in this order:

    - the top-level value is a JSON object (dict);
    - the fields ``license``, ``signature``, ``format`` and ``algorithm`` exist;
    - ``format == "AIMMS-1.0"`` and ``algorithm == "RSA-PSS-SHA256"``;
    - ``license`` is a dict containing ``email``, ``type``, ``version``,
      ``issued_at``, ``app_version`` and ``features``;
    - ``signature`` is a non-empty string.

    The signature itself is NOT verified here.

    Args:
        license_file_data: Result of ``json.load`` on a license file. Any
            JSON value is accepted; non-objects are rejected.

    Returns:
        tuple[bool, str]: (is_valid, error_message). The message is
        ``"Valid format"`` on success.
    """
    # json.load may return a list, str or number. For a str, `field in data`
    # is a substring test that could pass, so reject non-dicts up front.
    if not isinstance(license_file_data, dict):
        return False, "Invalid license file structure"

    # Check required top-level fields
    for field in ("license", "signature", "format", "algorithm"):
        if field not in license_file_data:
            return False, f"Missing required field: {field}"

    # Check format
    file_format = license_file_data["format"]
    if file_format != "AIMMS-1.0":
        return False, f"Unsupported license format: {file_format}"

    # Check algorithm
    algorithm = license_file_data["algorithm"]
    if algorithm != "RSA-PSS-SHA256":
        return False, f"Unsupported signature algorithm: {algorithm}"

    # Check license data structure
    license_data = license_file_data["license"]
    if not isinstance(license_data, dict):
        return False, "Invalid license data structure"

    # Check required license fields
    for field in ("email", "type", "version", "issued_at", "app_version", "features"):
        if field not in license_data:
            return False, f"Missing required license field: {field}"

    # Check signature format
    signature = license_file_data["signature"]
    if not isinstance(signature, str) or not signature:
        return False, "Invalid signature format"

    return True, "Valid format"

def validate_license_expiration(license_data: Dict[str, Any]) -> Tuple[bool, str]:
    """
    Validate license expiration.

    ``"permanent"`` licenses never expire. ``"trial"`` licenses must carry
    ``features["expiration"]`` as an ISO-8601 string, which is compared with
    the current time:

    - a trailing ``"Z"`` means UTC;
    - an explicit offset such as ``"+05:00"`` is honored;
    - a value without any offset is interpreted as local time.

    Any other type is rejected.

    Args:
        license_data: The ``"license"`` section of a license file.

    Returns:
        tuple[bool, str]: (is_valid, error_message)
    """
    license_type = license_data.get("type", "")

    # Permanent licenses don't expire
    if license_type == "permanent":
        return True, "Permanent license - no expiration"

    # Trial licenses have expiration
    if license_type == "trial":
        features = license_data.get("features", {})
        # Guard against null or non-object "features" instead of raising AttributeError.
        expiration_str = features.get("expiration") if isinstance(features, dict) else None
        if not expiration_str:
            return False, "Trial license missing expiration date"

        try:
            # "Z" means UTC. Translate it to an explicit offset, which
            # fromisoformat accepts on all Python 3.7+ versions, rather than
            # dropping it and silently comparing UTC against local time.
            text = expiration_str
            if text.endswith("Z"):
                text = text[:-1] + "+00:00"
            expiration_date = datetime.fromisoformat(text)
            if expiration_date.tzinfo is None:
                # No offset given: interpret as local time (the previous
                # behavior for such values), made aware so it can be compared.
                expiration_date = expiration_date.astimezone()

            if datetime.now(timezone.utc) > expiration_date:
                return False, "Trial license has expired"

            return True, "Trial license is valid"

        except Exception as e:
            return False, f"Invalid expiration date format: {str(e)}"

    return False, f"Unknown license type: {license_type}"

def validate_license(license_file_path: str) -> Tuple[bool, str, Dict[str, Any]]:
    """
    Validate a signed license file end to end.

    Steps, in order:

    1. Read and parse the JSON file (UTF-8).
    2. Check its structure.
    3. Load the public key.
    4. Verify the RSA-PSS signature over the ``license`` section.
    5. Check expiration.

    Args:
        license_file_path: Path to the license file.

    Returns:
        tuple[bool, str, dict]: (is_valid, message, license_data).
        ``license_data`` is the file's ``license`` section on success and
        ``{}`` on failure. Errors derived from ``Exception`` are reported
        through ``message`` rather than raised.
    """
    # Read the license file in its own try, so a FileNotFoundError here can
    # only mean the license file itself is missing and is never confused with
    # a missing public key.
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can mis-decode non-ASCII content such as email addresses.
        with open(license_file_path, "r", encoding="utf-8") as f:
            license_file_data = json.load(f)
    except FileNotFoundError:
        return False, f"License file not found: {license_file_path}", {}
    except json.JSONDecodeError:
        return False, "Invalid JSON format in license file", {}
    except Exception as e:
        return False, f"License validation error: {str(e)}", {}

    try:
        # Validate file format
        format_valid, format_message = validate_license_file_format(license_file_data)
        if not format_valid:
            return False, f"Invalid license format: {format_message}", {}

        # Extract license data and signature
        license_data = license_file_data["license"]
        signature = license_file_data["signature"]

        # Validate signature
        public_key = load_public_key()
        if not validate_license_signature(license_data, signature, public_key):
            return False, "Invalid license signature - license may be tampered with", {}

        # Validate expiration
        expiration_valid, expiration_message = validate_license_expiration(license_data)
        if not expiration_valid:
            return False, expiration_message, {}

        # License is valid
        return True, "License is valid", license_data

    except Exception as e:
        return False, f"License validation error: {str(e)}", {}

def validate_license_key(email: str, key: str) -> Tuple[bool, str, Dict[str, Any]]:
    """
    Locate and validate the license file belonging to ``email``.

    This is a backward-compatible wrapper: the new system validates signed
    ``.lic`` files rather than keys. Candidates are searched relative to the
    current working directory, in this order:

    1. When ``key == "asymmetric_license"`` and ``licensing/`` exists: every
       ``*.lic`` file below it (recursively). Files whose name contains the
       email's local part come first, in sorted order.
    2. ``licensing/<local>_permanent.lic``, then ``licensing/<local>_trial.lic``.

    A candidate is accepted only if it validates AND its signed ``email``
    field equals ``email`` exactly. A file named after another user, or a
    shared test license, can therefore never unlock this account.

    Args:
        email: User's email address.
        key: Legacy license key. Only the marker ``"asymmetric_license"``
            changes behavior; it enables the recursive search.

    Returns:
        tuple[bool, str, dict]: (is_valid, message, license_data). If nothing
        is accepted, returns the failure from the first name-matched or
        conventionally named file (e.g. an expiry message). If there is no
        such failure, returns ``(False, "No license file found for this email", {})``.
    """
    local_part = email.split('@')[0]
    licensing_dir = Path("licensing")

    # (path, report_failure) pairs in priority order. report_failure marks
    # files plausibly belonging to this user; their error (e.g. "expired") is
    # worth returning if nothing validates.
    candidates = []
    if key == "asymmetric_license" and licensing_dir.is_dir():
        lic_files = sorted(p for p in licensing_dir.rglob("*.lic") if p.is_file())
        # Plain substring test instead of a glob built from the email, so glob
        # metacharacters or an empty local part cannot break the pattern.
        candidates.extend((p, True) for p in lic_files if local_part in p.stem)
        candidates.extend((p, False) for p in lic_files if local_part not in p.stem)
    candidates.append((licensing_dir / f"{local_part}_permanent.lic", True))
    candidates.append((licensing_dir / f"{local_part}_trial.lic", True))

    first_failure = None
    checked = set()
    for path, report_failure in candidates:
        if path in checked or not path.is_file():
            continue
        checked.add(path)
        result = validate_license(str(path))
        if result[0]:
            # The signed email, not the file name, binds a license to a user.
            if result[2].get("email") == email:
                return result
        elif report_failure and first_failure is None:
            first_failure = result

    if first_failure is not None:
        return first_failure
    return False, "No license file found for this email", {}

def validate_trial_license(email: str, key: str) -> Tuple[bool, str, bool, bool]:
    """
    Validate the trial license file for ``email`` with the asymmetric system.

    Looks only at ``licensing/<local-part>_trial.lic``, relative to the CWD.

    Args:
        email: User's email address.
        key: License key (not used in new system).

    Returns:
        tuple[bool, str, bool, bool]: (is_valid, message, is_permanent, is_expired).
        An expired trial is reported as
        ``(False, "Trial license found but expired", False, True)``.
    """
    # Look for trial license file
    trial_file_path = f"licensing/{email.split('@')[0]}_trial.lic"

    if not Path(trial_file_path).exists():
        # No trial license file found
        return False, "No trial license found", False, False

    is_valid, message, license_data = validate_license(trial_file_path)

    if not is_valid:
        # validate_license rejects expired trials outright, with this exact
        # message from validate_license_expiration. Recognize it here;
        # otherwise is_expired could never be True. Keep is_valid False so an
        # expired trial is never reported as usable.
        if message == "Trial license has expired":
            return False, "Trial license found but expired", False, True
        return False, message, False, False

    license_type = license_data.get("type", "")

    if license_type == "permanent":
        return True, message, True, False

    if license_type == "trial":
        features = license_data.get("features", {})
        expiration_str = features.get("expiration") if isinstance(features, dict) else None

        if expiration_str:
            try:
                # "Z" means UTC; translate it to an explicit offset instead of
                # dropping it and comparing UTC against local time.
                text = expiration_str
                if text.endswith("Z"):
                    text = text[:-1] + "+00:00"
                expiration_date = datetime.fromisoformat(text)
                if expiration_date.tzinfo is None:
                    # No offset given: interpret as local time.
                    expiration_date = expiration_date.astimezone()
                current_date = datetime.now(timezone.utc)

                # Only reachable if expiry passes between validate_license's
                # check and this one.
                if current_date > expiration_date:
                    return False, "Trial license found but expired", False, True

                days_remaining = (expiration_date - current_date).days
                return True, f"Trial active - {days_remaining} days remaining", False, False
            except Exception:
                return True, "Trial license found", False, False

        return True, "Trial license found", False, False

    # validate_license rejects unknown types today. Return an explicit result
    # rather than falling off the end and returning None.
    return False, f"Unknown license type: {license_type}", False, False

# Test functions for development
def test_validate_valid_license():
    """
    Validate ``licensing/test_permanent.lic`` and report the outcome.

    Prints the full (is_valid, message, license_data) tuple.

    Returns:
        bool: The is_valid flag from validate_license.
    """
    print("Testing valid license validation...")
    is_valid, message, license_data = validate_license("licensing/test_permanent.lic")
    print(f"Result: {(is_valid, message, license_data)}")
    return is_valid

def test_validate_invalid_license():
    """
    Check that a license carrying a bogus signature is rejected.

    The fake license is written to a temporary directory that is removed
    afterwards, not into ``licensing/``. This way it cannot be picked up later
    by code that scans ``licensing/`` for ``*.lic`` files, and the test does
    not depend on that directory existing.

    Returns:
        bool: True if validation correctly failed.
    """
    import tempfile

    print("Testing invalid license validation...")
    # Create a fake license file
    fake_license = {
        "license": {
            "email": "fake@example.com",
            "type": "permanent",
            "version": "1.0",
            "issued_at": "2026-01-27T00:00:00Z",
            "app_version": "1.0",
            "features": {"all_features": True, "expiration": None}
        },
        "signature": "invalid_signature",
        "format": "AIMMS-1.0",
        "algorithm": "RSA-PSS-SHA256"
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        fake_path = os.path.join(tmp_dir, "test_invalid.lic")
        with open(fake_path, "w", encoding="utf-8") as f:
            json.dump(fake_license, f)
        result = validate_license(fake_path)

    print(f"Result: {result}")
    return not result[0]

def test_validate_tampered_license():
    """
    Check that editing a signed license (changing its email) breaks validation.

    The source is ``licensing/test_permanent.lic``. The tampered copy goes to
    a temporary directory that is removed afterwards, so no extra ``.lic``
    file is left in ``licensing/``.

    Returns:
        bool: True if the tampered license was correctly rejected; False if
        it validated or the source license is missing.
    """
    import tempfile

    print("Testing tampered license validation...")

    source_path = "licensing/test_permanent.lic"
    if not os.path.exists(source_path):
        # Report a failure instead of crashing the remaining tests.
        print(f"Source license not found: {source_path}")
        return False

    # Read the valid license
    with open(source_path, "r", encoding="utf-8") as f:
        license_data = json.load(f)

    # Tamper with the license data
    license_data["license"]["email"] = "tampered@example.com"

    with tempfile.TemporaryDirectory() as tmp_dir:
        tampered_path = os.path.join(tmp_dir, "test_tampered.lic")
        with open(tampered_path, "w", encoding="utf-8") as f:
            json.dump(license_data, f)
        result = validate_license(tampered_path)

    print(f"Result: {result}")
    return not result[0]

if __name__ == "__main__":
    print("AIMMS New License Validation System")
    print("=" * 50)

    # Run tests
    print("\nRunning validation tests...")

    outcomes = []
    for label, test_func in (
        ("Valid license test", test_validate_valid_license),
        ("Invalid license test", test_validate_invalid_license),
        ("Tampered license test", test_validate_tampered_license),
    ):
        passed = test_func()
        outcomes.append(passed)
        print(f"{label}: {'✅ PASSED' if passed else '❌ FAILED'}")

    print("\nAll tests completed!")
    # Exit non-zero on any failure so scripts and CI can detect it.
    sys.exit(0 if all(outcomes) else 1)