""" Site service for managing reusable site definitions. This service handles the business logic for creating, updating, and managing sites with their associated CIDR ranges and IP-level overrides. """ import ipaddress import json import logging from datetime import datetime from typing import Any, Dict, List, Optional from sqlalchemy import func from sqlalchemy.orm import Session, joinedload from web.models import ( Site, SiteIP, ScanSiteAssociation ) from web.utils.pagination import paginate, PaginatedResult logger = logging.getLogger(__name__) class SiteService: """ Service for managing reusable site definitions. Handles site lifecycle: creation, updates, deletion (with safety checks), CIDR management, and IP-level overrides. """ def __init__(self, db_session: Session): """ Initialize site service. Args: db_session: SQLAlchemy database session """ self.db = db_session def create_site(self, name: str, description: Optional[str] = None) -> Dict[str, Any]: """ Create a new site. Args: name: Unique site name description: Optional site description Returns: Dictionary with created site data Raises: ValueError: If site name already exists """ # Validate site name is unique existing = self.db.query(Site).filter(Site.name == name).first() if existing: raise ValueError(f"Site with name '{name}' already exists") # Create site (can be empty, IPs added separately) site = Site( name=name, description=description, created_at=datetime.utcnow(), updated_at=datetime.utcnow() ) self.db.add(site) self.db.commit() self.db.refresh(site) logger.info(f"Created site '{name}' (id={site.id})") return self._site_to_dict(site) def update_site(self, site_id: int, name: Optional[str] = None, description: Optional[str] = None) -> Dict[str, Any]: """ Update site metadata (name and/or description). Args: site_id: Site ID to update name: New site name (must be unique) description: New description Returns: Dictionary with updated site data Raises: ValueError: If site not found or name already exists """ site = self.db.query(Site).filter(Site.id == site_id).first() if not site: raise ValueError(f"Site with id {site_id} not found") # Update name if provided if name is not None and name != site.name: # Check uniqueness existing = self.db.query(Site).filter( Site.name == name, Site.id != site_id ).first() if existing: raise ValueError(f"Site with name '{name}' already exists") site.name = name # Update description if provided if description is not None: site.description = description site.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(site) logger.info(f"Updated site {site_id} ('{site.name}')") return self._site_to_dict(site) def delete_site(self, site_id: int) -> None: """ Delete a site. Prevents deletion if the site is used in any scan (per user requirement). Args: site_id: Site ID to delete Raises: ValueError: If site not found or is used in scans """ site = self.db.query(Site).filter(Site.id == site_id).first() if not site: raise ValueError(f"Site with id {site_id} not found") # Check if site is used in any scans usage_count = ( self.db.query(func.count(ScanSiteAssociation.id)) .filter(ScanSiteAssociation.site_id == site_id) .scalar() ) if usage_count > 0: raise ValueError( f"Cannot delete site '{site.name}': it is used in {usage_count} scan(s). " f"Sites that have been used in scans cannot be deleted." ) # Safe to delete self.db.delete(site) self.db.commit() logger.info(f"Deleted site {site_id} ('{site.name}')") def get_site(self, site_id: int) -> Optional[Dict[str, Any]]: """ Get site details. Args: site_id: Site ID to retrieve Returns: Dictionary with site data, or None if not found """ site = ( self.db.query(Site) .filter(Site.id == site_id) .first() ) if not site: return None return self._site_to_dict(site) def get_site_by_name(self, name: str) -> Optional[Dict[str, Any]]: """ Get site details by name. Args: name: Site name to retrieve Returns: Dictionary with site data, or None if not found """ site = ( self.db.query(Site) .filter(Site.name == name) .first() ) if not site: return None return self._site_to_dict(site) def list_sites(self, page: int = 1, per_page: int = 20) -> PaginatedResult: """ List all sites with pagination. Args: page: Page number (1-indexed) per_page: Number of items per page Returns: PaginatedResult with site data """ query = ( self.db.query(Site) .order_by(Site.name) ) return paginate(query, page, per_page, self._site_to_dict) def list_all_sites(self) -> List[Dict[str, Any]]: """ List all sites without pagination (for dropdowns, etc.). Returns: List of site dictionaries """ sites = ( self.db.query(Site) .order_by(Site.name) .all() ) return [self._site_to_dict(site) for site in sites] def bulk_add_ips_from_cidr(self, site_id: int, cidr: str, expected_ping: Optional[bool] = None, expected_tcp_ports: Optional[List[int]] = None, expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]: """ Expand a CIDR range and add all IPs to a site. CIDRs are NOT stored - they are just used to generate IP records. Args: site_id: Site ID cidr: CIDR notation (e.g., "10.0.0.0/24") expected_ping: Expected ping response for all IPs expected_tcp_ports: List of expected TCP ports for all IPs expected_udp_ports: List of expected UDP ports for all IPs Returns: Dictionary with: - cidr: The CIDR that was expanded - ip_count: Number of IPs created - ips_added: List of IP addresses created - ips_skipped: List of IPs that already existed Raises: ValueError: If site not found or CIDR is invalid/too large """ site = self.db.query(Site).filter(Site.id == site_id).first() if not site: raise ValueError(f"Site with id {site_id} not found") # Validate CIDR format and size try: network = ipaddress.ip_network(cidr, strict=False) except ValueError as e: raise ValueError(f"Invalid CIDR notation '{cidr}': {str(e)}") # Enforce CIDR size limits (max /24 for IPv4, /64 for IPv6) if isinstance(network, ipaddress.IPv4Network) and network.prefixlen < 24: raise ValueError( f"CIDR '{cidr}' is too large ({network.num_addresses} IPs). " f"Maximum allowed is /24 (256 IPs) for IPv4." ) elif isinstance(network, ipaddress.IPv6Network) and network.prefixlen < 64: raise ValueError( f"CIDR '{cidr}' is too large. " f"Maximum allowed is /64 for IPv6." ) # Expand CIDR to individual IPs (no cidr_id since we're not storing CIDR) ip_count, ips_added, ips_skipped = self._expand_cidr_to_ips( site_id=site_id, network=network, expected_ping=expected_ping, expected_tcp_ports=expected_tcp_ports or [], expected_udp_ports=expected_udp_ports or [] ) site.updated_at = datetime.utcnow() self.db.commit() logger.info( f"Expanded CIDR '{cidr}' for site {site_id} ('{site.name}'): " f"added {ip_count} IPs, skipped {len(ips_skipped)} duplicates" ) return { 'cidr': cidr, 'ip_count': ip_count, 'ips_added': ips_added, 'ips_skipped': ips_skipped } def bulk_add_ips_from_list(self, site_id: int, ip_list: List[str], expected_ping: Optional[bool] = None, expected_tcp_ports: Optional[List[int]] = None, expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]: """ Add multiple IPs from a list (e.g., from CSV/text import). Args: site_id: Site ID ip_list: List of IP addresses as strings expected_ping: Expected ping response for all IPs expected_tcp_ports: List of expected TCP ports for all IPs expected_udp_ports: List of expected UDP ports for all IPs Returns: Dictionary with: - ip_count: Number of IPs successfully created - ips_added: List of IP addresses created - ips_skipped: List of IPs that already existed - errors: List of validation errors {ip: error_message} Raises: ValueError: If site not found """ site = self.db.query(Site).filter(Site.id == site_id).first() if not site: raise ValueError(f"Site with id {site_id} not found") ips_added = [] ips_skipped = [] errors = [] for ip_str in ip_list: ip_str = ip_str.strip() if not ip_str: continue # Skip empty lines # Validate IP format try: ipaddress.ip_address(ip_str) except ValueError as e: errors.append({'ip': ip_str, 'error': f"Invalid IP address: {str(e)}"}) continue # Check for duplicate (across all IPs in the site) existing = ( self.db.query(SiteIP) .filter(SiteIP.site_id == site_id, SiteIP.ip_address == ip_str) .first() ) if existing: ips_skipped.append(ip_str) continue # Create IP record try: ip_obj = SiteIP( site_id=site_id, ip_address=ip_str, expected_ping=expected_ping, expected_tcp_ports=json.dumps(expected_tcp_ports or []), expected_udp_ports=json.dumps(expected_udp_ports or []), created_at=datetime.utcnow() ) self.db.add(ip_obj) ips_added.append(ip_str) except Exception as e: errors.append({'ip': ip_str, 'error': f"Database error: {str(e)}"}) site.updated_at = datetime.utcnow() self.db.commit() logger.info( f"Bulk added {len(ips_added)} IPs to site {site_id} ('{site.name}'), " f"skipped {len(ips_skipped)} duplicates, {len(errors)} errors" ) return { 'ip_count': len(ips_added), 'ips_added': ips_added, 'ips_skipped': ips_skipped, 'errors': errors } def add_standalone_ip(self, site_id: int, ip_address: str, expected_ping: Optional[bool] = None, expected_tcp_ports: Optional[List[int]] = None, expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]: """ Add a standalone IP (without a CIDR parent) to a site. Args: site_id: Site ID ip_address: IP address to add expected_ping: Expected ping response expected_tcp_ports: List of expected TCP ports expected_udp_ports: List of expected UDP ports Returns: Dictionary with IP data Raises: ValueError: If site not found, IP is invalid, or already exists """ site = self.db.query(Site).filter(Site.id == site_id).first() if not site: raise ValueError(f"Site with id {site_id} not found") # Validate IP format try: ipaddress.ip_address(ip_address) except ValueError as e: raise ValueError(f"Invalid IP address '{ip_address}': {str(e)}") # Check for duplicate (across all IPs in the site) existing = ( self.db.query(SiteIP) .filter(SiteIP.site_id == site_id, SiteIP.ip_address == ip_address) .first() ) if existing: raise ValueError(f"IP '{ip_address}' already exists in this site") # Create IP ip_obj = SiteIP( site_id=site_id, ip_address=ip_address, expected_ping=expected_ping, expected_tcp_ports=json.dumps(expected_tcp_ports or []), expected_udp_ports=json.dumps(expected_udp_ports or []), created_at=datetime.utcnow() ) self.db.add(ip_obj) site.updated_at = datetime.utcnow() self.db.commit() self.db.refresh(ip_obj) logger.info(f"Added IP '{ip_address}' to site {site_id} ('{site.name}')") return self._ip_to_dict(ip_obj) def update_ip_settings(self, site_id: int, ip_id: int, expected_ping: Optional[bool] = None, expected_tcp_ports: Optional[List[int]] = None, expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]: """ Update settings for an individual IP. Args: site_id: Site ID ip_id: IP ID to update expected_ping: New ping expectation (if provided) expected_tcp_ports: New TCP ports expectation (if provided) expected_udp_ports: New UDP ports expectation (if provided) Returns: Dictionary with updated IP data Raises: ValueError: If IP not found """ ip_obj = ( self.db.query(SiteIP) .filter(SiteIP.id == ip_id, SiteIP.site_id == site_id) .first() ) if not ip_obj: raise ValueError(f"IP with id {ip_id} not found for site {site_id}") # Update settings if provided if expected_ping is not None: ip_obj.expected_ping = expected_ping if expected_tcp_ports is not None: ip_obj.expected_tcp_ports = json.dumps(expected_tcp_ports) if expected_udp_ports is not None: ip_obj.expected_udp_ports = json.dumps(expected_udp_ports) self.db.commit() self.db.refresh(ip_obj) logger.info(f"Updated settings for IP '{ip_obj.ip_address}' in site {site_id}") return self._ip_to_dict(ip_obj) def remove_ip(self, site_id: int, ip_id: int) -> None: """ Remove an IP from a site. Args: site_id: Site ID ip_id: IP ID to remove Raises: ValueError: If IP not found """ ip_obj = ( self.db.query(SiteIP) .filter(SiteIP.id == ip_id, SiteIP.site_id == site_id) .first() ) if not ip_obj: raise ValueError(f"IP with id {ip_id} not found for site {site_id}") ip_address = ip_obj.ip_address self.db.delete(ip_obj) self.db.commit() logger.info(f"Removed IP '{ip_address}' from site {site_id}") def list_ips(self, site_id: int, page: int = 1, per_page: int = 50) -> PaginatedResult: """ List IPs in a site with pagination. Args: site_id: Site ID page: Page number (1-indexed) per_page: Number of items per page Returns: PaginatedResult with IP data """ query = ( self.db.query(SiteIP) .filter(SiteIP.site_id == site_id) .order_by(SiteIP.ip_address) ) return paginate(query, page, per_page, self._ip_to_dict) def get_scan_usage(self, site_id: int) -> List[Dict[str, Any]]: """ Get list of scans that use this site. Args: site_id: Site ID Returns: List of scan dictionaries """ from web.models import Scan # Import here to avoid circular dependency associations = ( self.db.query(ScanSiteAssociation) .options(joinedload(ScanSiteAssociation.scan)) .filter(ScanSiteAssociation.site_id == site_id) .all() ) return [ { 'id': assoc.scan.id, 'title': assoc.scan.title, 'timestamp': assoc.scan.timestamp.isoformat() if assoc.scan.timestamp else None, 'status': assoc.scan.status } for assoc in associations ] # Private helper methods def _expand_cidr_to_ips(self, site_id: int, network: ipaddress.IPv4Network | ipaddress.IPv6Network, expected_ping: Optional[bool], expected_tcp_ports: List[int], expected_udp_ports: List[int]) -> tuple[int, List[str], List[str]]: """ Expand a CIDR to individual IP addresses. Args: site_id: Site ID network: ipaddress network object expected_ping: Default ping setting for all IPs expected_tcp_ports: Default TCP ports for all IPs expected_udp_ports: Default UDP ports for all IPs Returns: Tuple of (count of IPs created, list of IPs added, list of IPs skipped) """ ip_count = 0 ips_added = [] ips_skipped = [] # For /32 or /128 (single host), use the network address # For larger ranges, use hosts() to exclude network/broadcast addresses if network.num_addresses == 1: ip_list = [network.network_address] elif network.num_addresses == 2: # For /31 networks (point-to-point), both addresses are usable ip_list = [network.network_address, network.broadcast_address] else: # Use hosts() to get usable IPs (excludes network and broadcast) ip_list = list(network.hosts()) for ip in ip_list: ip_str = str(ip) # Check for duplicate existing = ( self.db.query(SiteIP) .filter(SiteIP.site_id == site_id, SiteIP.ip_address == ip_str) .first() ) if existing: ips_skipped.append(ip_str) continue # Create SiteIP entry ip_obj = SiteIP( site_id=site_id, ip_address=ip_str, expected_ping=expected_ping, expected_tcp_ports=json.dumps(expected_tcp_ports), expected_udp_ports=json.dumps(expected_udp_ports), created_at=datetime.utcnow() ) self.db.add(ip_obj) ips_added.append(ip_str) ip_count += 1 return ip_count, ips_added, ips_skipped def _site_to_dict(self, site: Site) -> Dict[str, Any]: """Convert Site model to dictionary.""" # Count IPs for this site ip_count = ( self.db.query(func.count(SiteIP.id)) .filter(SiteIP.site_id == site.id) .scalar() or 0 ) return { 'id': site.id, 'name': site.name, 'description': site.description, 'created_at': site.created_at.isoformat() if site.created_at else None, 'updated_at': site.updated_at.isoformat() if site.updated_at else None, 'ip_count': ip_count } def _ip_to_dict(self, ip: SiteIP) -> Dict[str, Any]: """Convert SiteIP model to dictionary.""" return { 'id': ip.id, 'site_id': ip.site_id, 'ip_address': ip.ip_address, 'expected_ping': ip.expected_ping, 'expected_tcp_ports': json.loads(ip.expected_tcp_ports) if ip.expected_tcp_ports else [], 'expected_udp_ports': json.loads(ip.expected_udp_ports) if ip.expected_udp_ports else [], 'created_at': ip.created_at.isoformat() if ip.created_at else None }