#!/usr/bin/env python3
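"""
Unit tests for entra-id-recon.py.

These tests exercise the reconnaissance and enumeration functions of the
script, mostly with deliberately invalid inputs (domains under the reserved
``.invalid`` TLD, the all-zero tenant GUID) to check how each function
reports that nothing was found.

Note: the functions are called directly, without mocks, so several tests
require network access and may send requests to Microsoft services.
"""
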
import importlib.util
import os
import sys
import unittest

# Directory that contains entra-id-recon.py (the parent of this tests folder).
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))

# Put the repository root on sys.path; presumably so that anything the script
# imports from next to itself can be resolved when it is executed below.
sys.path.insert(0, _REPO_ROOT)

# The script's filename contains hyphens, so it cannot be imported with a plain
# ``import`` statement; load it from its path under a valid module name instead.
_MODULE_NAME = "entra_id_recon"
spec = importlib.util.spec_from_file_location(
    _MODULE_NAME, os.path.join(_REPO_ROOT, 'entra-id-recon.py')
)
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot build an import spec for {_MODULE_NAME} in {_REPO_ROOT}")
entra_id_recon = importlib.util.module_from_spec(spec)
# Register the module before executing it, as the importlib documentation's
# "import a source file directly" recipe does, so lookups of the module by name
# in sys.modules (during or after execution) find this object.
sys.modules[_MODULE_NAME] = entra_id_recon
spec.loader.exec_module(entra_id_recon)

# Short aliases for the functions under test.
resolve_dns = entra_id_recon.resolve_dns
get_tenant_id = entra_id_recon.get_tenant_id
get_tenant_brand_and_sso = entra_id_recon.get_tenant_brand_and_sso
check_device_code_auth = entra_id_recon.check_device_code_auth
get_tenant_domains = entra_id_recon.get_tenant_domains
get_credential_type_info = entra_id_recon.get_credential_type_info
check_sharepoint = entra_id_recon.check_sharepoint
check_azure_services = entra_id_recon.check_azure_services


class TestDNSResolution(unittest.TestCase):
    """Checks for resolve_dns()."""

    def test_resolve_dns_valid_domain(self):
        """An A-record query for a well-known public domain comes back as a list."""
        # Live lookup: needs working DNS / network access.
        answers = resolve_dns("google.com", "A")
        self.assertIsInstance(answers, list)

    def test_resolve_dns_invalid_domain(self):
        """A name under the reserved .invalid TLD yields an empty list."""
        unresolvable = "nonexistent-domain-12345.invalid"
        self.assertEqual(resolve_dns(unresolvable, "A"), [])


class TestTenantInformation(unittest.TestCase):
    """Checks for tenant lookups against a domain that cannot exist."""

    def test_get_tenant_id_invalid_domain(self):
        """Both the tenant ID and the region are None for an unknown domain."""
        tenant_id, region = get_tenant_id("nonexistent-domain-12345.invalid")
        for value in (tenant_id, region):
            self.assertIsNone(value)

    def test_get_tenant_brand_invalid_domain(self):
        """Both the brand and the SSO value are None for an unknown domain."""
        brand, sso = get_tenant_brand_and_sso("nonexistent-domain-12345.invalid")
        for value in (brand, sso):
            self.assertIsNone(value)


class TestDeviceCodeAuth(unittest.TestCase):
    """Test device code authentication detection."""

    def test_check_device_code_auth_no_tenant_id(self):
        """Without a tenant ID the check reports a falsy result."""
        result = check_device_code_auth(None)
        self.assertFalse(result)

    def test_check_device_code_auth_invalid_tenant_id(self):
        """Test device code auth check with the all-zero (nil) tenant GUID."""
        result = check_device_code_auth("00000000-0000-0000-0000-000000000000")
        # Only the return type is pinned here, not the value: the outcome may
        # depend on a live service response, so the test just requires that
        # the function answers with a bool.
        self.assertIsInstance(result, bool)


class TestAzureServices(unittest.TestCase):
    """Checks for the Azure-service and SharePoint detection helpers."""

    def test_check_azure_services_no_tenant_id(self):
        """With no tenant ID, an empty dict of services is returned."""
        found = check_azure_services("example.com", None)
        self.assertIsInstance(found, dict)
        self.assertEqual(found, {})

    def test_check_sharepoint_invalid_domain(self):
        """An unresolvable domain is not reported as having SharePoint."""
        self.assertFalse(check_sharepoint("nonexistent-domain-12345.invalid"))


class TestUserEnumeration(unittest.TestCase):
    """Test user enumeration functionality."""

    def test_get_credential_type_info_invalid_username(self):
        """Credential type info for an unknown user is either None or a dict."""
        result = get_credential_type_info("nonexistent-user-12345@invalid-domain.invalid")
        # Either outcome is acceptable. A type tuple accepts exactly the same
        # values as ``result is None or isinstance(result, dict)`` (None is the
        # only NoneType instance), but on failure it reports the actual type of
        # ``result`` instead of just "False is not true".
        self.assertIsInstance(result, (dict, type(None)))


class TestDomainRetrieval(unittest.TestCase):
    """Checks for get_tenant_domains()."""

    def test_get_tenant_domains_invalid_domain(self):
        """A domain that cannot exist produces None rather than a domain list."""
        self.assertIsNone(get_tenant_domains("nonexistent-domain-12345.invalid"))


if __name__ == "__main__":
    # Allow running this file directly, in addition to test discovery.
    unittest.main()