#!/usr/bin/env python3
"""
Unit tests for entra-id-recon.py.

These tests verify the core functionality of the reconnaissance and
enumeration modules.

Note: Some tests may require network access and may interact with
Microsoft services.
"""

import importlib.util
import os
import sys
import unittest

# Parent directory of this test file, which holds entra-id-recon.py.
_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
_MODULE_PATH = os.path.join(_REPO_ROOT, 'entra-id-recon.py')

# Add parent directory to path so imports performed by the script resolve.
sys.path.insert(0, _REPO_ROOT)

# The script's filename uses hyphens, so it cannot be imported with a plain
# ``import`` statement; load it directly from its file path instead.
spec = importlib.util.spec_from_file_location("entra_id_recon", _MODULE_PATH)
if spec is None or spec.loader is None:
    # Fail with a clear message instead of an AttributeError on ``spec.loader``.
    raise ImportError(f"Cannot create an import spec for {_MODULE_PATH}")
entra_id_recon = importlib.util.module_from_spec(spec)
# Register the module before executing it (as in the importlib "import a
# source file directly" recipe) so lookups through sys.modules succeed.
sys.modules[spec.name] = entra_id_recon
spec.loader.exec_module(entra_id_recon)

# Import functions from the loaded module
resolve_dns = entra_id_recon.resolve_dns
get_tenant_id = entra_id_recon.get_tenant_id
get_tenant_brand_and_sso = entra_id_recon.get_tenant_brand_and_sso
check_device_code_auth = entra_id_recon.check_device_code_auth
get_tenant_domains = entra_id_recon.get_tenant_domains
get_credential_type_info = entra_id_recon.get_credential_type_info
check_sharepoint = entra_id_recon.check_sharepoint
check_azure_services = entra_id_recon.check_azure_services


class TestDNSResolution(unittest.TestCase):
    """Test DNS resolution functionality."""

    def test_resolve_dns_valid_domain(self):
        """Test DNS resolution for a valid domain."""
        # This test requires network access
        results = resolve_dns("google.com", "A")
        self.assertIsInstance(results, list)

    def test_resolve_dns_invalid_domain(self):
        """Test DNS resolution for an invalid domain."""
        results = resolve_dns("nonexistent-domain-12345.invalid", "A")
        self.assertEqual(results, [])


class TestTenantInformation(unittest.TestCase):
    """Test tenant information retrieval."""

    def test_get_tenant_id_invalid_domain(self):
        """Test tenant ID retrieval for invalid domain."""
        tenant_id, region = get_tenant_id("nonexistent-domain-12345.invalid")
        self.assertIsNone(tenant_id)
        self.assertIsNone(region)

    def test_get_tenant_brand_invalid_domain(self):
        """Test tenant brand retrieval for invalid domain."""
        brand, sso = get_tenant_brand_and_sso("nonexistent-domain-12345.invalid")
        self.assertIsNone(brand)
        self.assertIsNone(sso)


class TestDeviceCodeAuth(unittest.TestCase):
    """Test device code authentication detection."""

    def test_check_device_code_auth_no_tenant_id(self):
        """Test device code auth check with no tenant ID."""
        result = check_device_code_auth(None)
        self.assertFalse(result)

    def test_check_device_code_auth_invalid_tenant_id(self):
        """Test device code auth check with invalid tenant ID."""
        result = check_device_code_auth("00000000-0000-0000-0000-000000000000")
        # Only the return type is asserted here; the concrete value for an
        # all-zero tenant ID is not pinned by this test.
        self.assertIsInstance(result, bool)


class TestAzureServices(unittest.TestCase):
    """Test Azure services detection."""

    def test_check_azure_services_no_tenant_id(self):
        """Test Azure services check with no tenant ID."""
        results = check_azure_services("example.com", None)
        self.assertIsInstance(results, dict)
        self.assertEqual(results, {})

    def test_check_sharepoint_invalid_domain(self):
        """Test SharePoint detection for invalid domain."""
        result = check_sharepoint("nonexistent-domain-12345.invalid")
        self.assertFalse(result)


class TestUserEnumeration(unittest.TestCase):
    """Test user enumeration functionality."""

    def test_get_credential_type_info_invalid_username(self):
        """Test credential type info for invalid username."""
        result = get_credential_type_info(
            "nonexistent-user-12345@invalid-domain.invalid"
        )
        # Should return None or a valid response structure
        self.assertTrue(
            result is None or isinstance(result, dict),
            f"expected None or dict, got {type(result).__name__}",
        )


class TestDomainRetrieval(unittest.TestCase):
    """Test domain retrieval functionality."""

    def test_get_tenant_domains_invalid_domain(self):
        """Test tenant domains retrieval for invalid domain."""
        result = get_tenant_domains("nonexistent-domain-12345.invalid")
        self.assertIsNone(result)


if __name__ == '__main__':
    unittest.main()