Files
entra-id-recon.py/tests/test_entra_id_recon.py
Warezpeddler 3e873f03a8 Initial commit
2026-01-29 00:15:12 +00:00

116 lines
4.3 KiB
Python

#!/usr/bin/env python3
"""
Unit tests for entra-id-recon.py
These tests verify the core functionality of the reconnaissance and enumeration modules.
Note: Some tests may require network access and may interact with Microsoft services.
"""
import unittest
import sys
import os
import importlib.util

# Make the project root importable so the script under test can resolve any
# sibling imports it performs while being loaded.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.insert(0, _PROJECT_ROOT)

# The script's filename contains hyphens, so it cannot be imported with a plain
# ``import`` statement; load it directly from its file path instead.
_MODULE_NAME = "entra_id_recon"
_MODULE_PATH = os.path.join(_PROJECT_ROOT, 'entra-id-recon.py')
spec = importlib.util.spec_from_file_location(_MODULE_NAME, _MODULE_PATH)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when no loader matches the path;
    # fail loudly here instead of with an AttributeError further down.
    raise ImportError(f"Cannot load {_MODULE_NAME} from {_MODULE_PATH}")
entra_id_recon = importlib.util.module_from_spec(spec)
# Register before executing (as in the importlib "importing a source file
# directly" recipe) so the module can be looked up by name, e.g. by
# unittest.mock.patch("entra_id_recon.<attr>").
sys.modules[_MODULE_NAME] = entra_id_recon
spec.loader.exec_module(entra_id_recon)

# Import functions from the loaded module
resolve_dns = entra_id_recon.resolve_dns
get_tenant_id = entra_id_recon.get_tenant_id
get_tenant_brand_and_sso = entra_id_recon.get_tenant_brand_and_sso
check_device_code_auth = entra_id_recon.check_device_code_auth
get_tenant_domains = entra_id_recon.get_tenant_domains
get_credential_type_info = entra_id_recon.get_credential_type_info
check_sharepoint = entra_id_recon.check_sharepoint
check_azure_services = entra_id_recon.check_azure_services
class TestDNSResolution(unittest.TestCase):
    """Exercise ``resolve_dns`` against resolvable and unresolvable names."""

    def test_resolve_dns_valid_domain(self):
        """A lookup of a real domain returns a list.

        Intended to run with network access; only the return type is asserted.
        """
        records = resolve_dns("google.com", "A")
        self.assertIsInstance(records, list)

    def test_resolve_dns_invalid_domain(self):
        """A lookup under the reserved ``.invalid`` TLD returns an empty list."""
        bogus_domain = "nonexistent-domain-12345.invalid"
        self.assertEqual(resolve_dns(bogus_domain, "A"), [])
class TestTenantInformation(unittest.TestCase):
    """Tenant lookups for an unresolvable domain should come back empty."""

    def test_get_tenant_id_invalid_domain(self):
        """An unresolvable domain yields ``None`` for both tenant ID and region."""
        bogus_domain = "nonexistent-domain-12345.invalid"
        tenant_id, region = get_tenant_id(bogus_domain)
        for value in (tenant_id, region):
            self.assertIsNone(value)

    def test_get_tenant_brand_invalid_domain(self):
        """An unresolvable domain yields ``None`` for both brand and SSO value."""
        bogus_domain = "nonexistent-domain-12345.invalid"
        brand, sso = get_tenant_brand_and_sso(bogus_domain)
        for value in (brand, sso):
            self.assertIsNone(value)
class TestDeviceCodeAuth(unittest.TestCase):
    """Test device code authentication detection."""

    def test_check_device_code_auth_no_tenant_id(self):
        """Without a tenant ID the check returns a falsy result."""
        result = check_device_code_auth(None)
        self.assertFalse(result)

    def test_check_device_code_auth_invalid_tenant_id(self):
        """Device code auth check with an all-zeros (invalid) tenant ID returns a bool."""
        result = check_device_code_auth("00000000-0000-0000-0000-000000000000")
        # Only the return type is asserted, not a specific value: the outcome
        # presumably depends on a live endpoint and network availability (not
        # visible from this file).
        # NOTE(review): tighten to assertFalse once the expected value is confirmed.
        self.assertIsInstance(result, bool)
class TestAzureServices(unittest.TestCase):
    """Azure service probes should report nothing for unusable inputs."""

    def test_check_azure_services_no_tenant_id(self):
        """Without a tenant ID the result is an empty dict."""
        # assertDictEqual checks the dict type and then equality in one call.
        self.assertDictEqual(check_azure_services("example.com", None), {})

    def test_check_sharepoint_invalid_domain(self):
        """SharePoint is not detected for an unresolvable domain."""
        detected = check_sharepoint("nonexistent-domain-12345.invalid")
        self.assertFalse(detected)
class TestUserEnumeration(unittest.TestCase):
    """Test user enumeration functionality."""

    def test_get_credential_type_info_invalid_username(self):
        """Credential type info for an address on an unresolvable domain."""
        result = get_credential_type_info("nonexistent-user-12345@invalid-domain.invalid")
        # Accept either no result (None) or a response dict. Unlike
        # assertTrue(a or b), which only reports "False is not true",
        # assertIsInstance names the unexpected type on failure.
        self.assertIsInstance(result, (dict, type(None)))
class TestDomainRetrieval(unittest.TestCase):
    """Cover ``get_tenant_domains`` for a domain that cannot be resolved."""

    def test_get_tenant_domains_invalid_domain(self):
        """An unresolvable domain produces ``None``."""
        self.assertIsNone(get_tenant_domains("nonexistent-domain-12345.invalid"))
if __name__ == "__main__":
    # Run the whole suite when this file is executed directly.
    unittest.main()