"""
Site service for managing reusable site definitions.

This service handles the business logic for creating, updating, and managing
sites with their associated CIDR ranges and IP-level overrides.
"""

import ipaddress
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import IntegrityError

from web.models import (
    Site, SiteCIDR, SiteIP, ScanSiteAssociation
)
from web.utils.pagination import paginate, PaginatedResult

logger = logging.getLogger(__name__)


class SiteService:
    """
    Service for managing reusable site definitions.

    Handles site lifecycle: creation, updates, deletion (with safety checks),
    CIDR management, and IP-level overrides.

    Every mutating method commits the session it was given. If the database
    rejects a commit with an integrity violation (e.g. a unique or foreign-key
    constraint, where the schema defines one), the session is rolled back and
    a ``ValueError`` is raised, so callers only need to handle one error type.
    """

    def __init__(self, db_session: Session):
        """
        Initialize site service.

        Args:
            db_session: SQLAlchemy database session
        """
        self.db = db_session

    def create_site(self, name: str, description: Optional[str] = None,
                    cidrs: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
        """
        Create a new site with optional CIDR ranges.

        Args:
            name: Unique site name
            description: Optional site description
            cidrs: List of CIDR definitions with format:
                [{"cidr": "10.0.0.0/24", "expected_ping": true,
                  "expected_tcp_ports": [22, 80], "expected_udp_ports": [53]}]

        Returns:
            Dictionary with created site data

        Raises:
            ValueError: If the site name already exists, ``cidrs`` is an empty
                list, a CIDR is missing/invalid or duplicated within ``cidrs``,
                or the database rejects the insert. On failure after the site
                has been staged, the session is rolled back so no partial site
                remains pending.
        """
        # Validate site name is unique (the commit below is the final guard
        # against a concurrent insert if the schema enforces uniqueness)
        existing = self.db.query(Site).filter(Site.name == name).first()
        if existing:
            raise ValueError(f"Site with name '{name}' already exists")

        # An explicitly supplied CIDR list must contain at least one entry
        if cidrs is not None and not cidrs:
            raise ValueError("Site must have at least one CIDR range")

        now = datetime.utcnow()  # single timestamp so created_at == updated_at
        site = Site(
            name=name,
            description=description,
            created_at=now,
            updated_at=now,
        )

        try:
            self.db.add(site)
            self.db.flush()  # Get site.id without committing

            seen_cidrs = set()
            for cidr_data in cidrs or []:
                cidr_obj = self._add_cidr_to_site(site, cidr_data)
                # Reject duplicates inside the request, mirroring add_cidr()
                if cidr_obj.cidr in seen_cidrs:
                    raise ValueError(
                        f"CIDR '{cidr_obj.cidr}' already exists for this site"
                    )
                seen_cidrs.add(cidr_obj.cidr)

            self.db.commit()
        except IntegrityError as exc:
            self.db.rollback()
            raise ValueError(f"Could not create site '{name}': {exc.orig}") from exc
        except Exception:
            # Don't leave the flushed-but-uncommitted site in the session,
            # where a later commit by the caller would persist it.
            self.db.rollback()
            raise

        self.db.refresh(site)

        logger.info("Created site '%s' (id=%s) with %d CIDR(s)",
                    name, site.id, len(cidrs or []))

        return self._site_to_dict(site)

    def update_site(self, site_id: int, name: Optional[str] = None,
                    description: Optional[str] = None) -> Dict[str, Any]:
        """
        Update site metadata (name and/or description).

        Args:
            site_id: Site ID to update
            name: New site name (must be unique)
            description: New description

        Returns:
            Dictionary with updated site data

        Raises:
            ValueError: If site not found, name already exists, or the
                database rejects the update
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        # Update name if provided
        if name is not None and name != site.name:
            # Check uniqueness against every other site
            existing = self.db.query(Site).filter(
                Site.name == name,
                Site.id != site_id
            ).first()
            if existing:
                raise ValueError(f"Site with name '{name}' already exists")
            site.name = name

        # Update description if provided
        if description is not None:
            site.description = description

        site.updated_at = datetime.utcnow()

        self._commit_or_raise(f"Could not update site {site_id}")
        self.db.refresh(site)

        logger.info("Updated site %s ('%s')", site_id, site.name)

        return self._site_to_dict(site)

    def delete_site(self, site_id: int) -> None:
        """
        Delete a site.

        Prevents deletion if the site is used in any scan (per user requirement).

        Args:
            site_id: Site ID to delete

        Raises:
            ValueError: If site not found, is used in scans, or the database
                rejects the deletion
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        # Check if site is used in any scans
        usage_count = (
            self.db.query(func.count(ScanSiteAssociation.id))
            .filter(ScanSiteAssociation.site_id == site_id)
            .scalar()
        )

        if usage_count > 0:
            raise ValueError(
                f"Cannot delete site '{site.name}': it is used in {usage_count} scan(s). "
                f"Sites that have been used in scans cannot be deleted."
            )

        # Capture the name before the object is deleted/detached
        site_name = site.name

        # Safe to delete
        self.db.delete(site)
        self._commit_or_raise(f"Cannot delete site '{site_name}'")

        logger.info("Deleted site %s ('%s')", site_id, site_name)

    def get_site(self, site_id: int) -> Optional[Dict[str, Any]]:
        """
        Get site details with all CIDRs and IP overrides.

        Args:
            site_id: Site ID to retrieve

        Returns:
            Dictionary with site data, or None if not found
        """
        site = self._detailed_site_query().filter(Site.id == site_id).first()

        if not site:
            return None

        return self._site_to_dict(site)

    def get_site_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """
        Get site details by name.

        Args:
            name: Site name to retrieve

        Returns:
            Dictionary with site data, or None if not found
        """
        site = self._detailed_site_query().filter(Site.name == name).first()

        if not site:
            return None

        return self._site_to_dict(site)

    def list_sites(self, page: int = 1, per_page: int = 20) -> PaginatedResult:
        """
        List all sites with pagination.

        Args:
            page: Page number (1-indexed)
            per_page: Number of items per page

        Returns:
            PaginatedResult with site data
        """
        # IP overrides are eager-loaded too: _site_to_dict serializes them
        query = self._detailed_site_query().order_by(Site.name)

        return paginate(query, page, per_page, self._site_to_dict)

    def list_all_sites(self) -> List[Dict[str, Any]]:
        """
        List all sites without pagination (for dropdowns, etc.).

        Returns:
            List of site dictionaries
        """
        sites = self._detailed_site_query().order_by(Site.name).all()

        return [self._site_to_dict(site) for site in sites]

    def add_cidr(self, site_id: int, cidr: str, expected_ping: Optional[bool] = None,
                 expected_tcp_ports: Optional[List[int]] = None,
                 expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]:
        """
        Add a CIDR range to a site.

        Args:
            site_id: Site ID
            cidr: CIDR notation (e.g., "10.0.0.0/24")
            expected_ping: Expected ping response for IPs in this CIDR
            expected_tcp_ports: List of expected TCP ports
            expected_udp_ports: List of expected UDP ports

        Returns:
            Dictionary with CIDR data

        Raises:
            ValueError: If site not found, CIDR is invalid, already exists,
                or the database rejects the insert
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        # Validate CIDR format (host bits allowed, hence strict=False)
        try:
            ipaddress.ip_network(cidr, strict=False)
        except ValueError as e:
            raise ValueError(f"Invalid CIDR notation '{cidr}': {str(e)}") from e

        # Check for duplicate CIDR (exact string match)
        existing = (
            self.db.query(SiteCIDR)
            .filter(SiteCIDR.site_id == site_id, SiteCIDR.cidr == cidr)
            .first()
        )
        if existing:
            raise ValueError(f"CIDR '{cidr}' already exists for this site")

        # Create CIDR; port lists are stored JSON-encoded
        cidr_obj = SiteCIDR(
            site_id=site_id,
            cidr=cidr,
            expected_ping=expected_ping,
            expected_tcp_ports=json.dumps(expected_tcp_ports or []),
            expected_udp_ports=json.dumps(expected_udp_ports or []),
            created_at=datetime.utcnow()
        )

        self.db.add(cidr_obj)
        site.updated_at = datetime.utcnow()
        site_name = site.name

        self._commit_or_raise(f"Could not add CIDR '{cidr}' to site {site_id}")
        self.db.refresh(cidr_obj)

        logger.info("Added CIDR '%s' to site %s ('%s')", cidr, site_id, site_name)

        return self._cidr_to_dict(cidr_obj)

    def remove_cidr(self, site_id: int, cidr_id: int) -> None:
        """
        Remove a CIDR range from a site.

        Prevents removal if it's the last CIDR (sites must have at least one CIDR).

        Args:
            site_id: Site ID
            cidr_id: CIDR ID to remove

        Raises:
            ValueError: If site or CIDR not found, it's the last CIDR, or the
                database rejects the deletion
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        cidr = (
            self.db.query(SiteCIDR)
            .filter(SiteCIDR.id == cidr_id, SiteCIDR.site_id == site_id)
            .first()
        )
        if not cidr:
            raise ValueError(f"CIDR with id {cidr_id} not found for site {site_id}")

        # Check if this is the last CIDR
        cidr_count = (
            self.db.query(func.count(SiteCIDR.id))
            .filter(SiteCIDR.site_id == site_id)
            .scalar()
        )

        if cidr_count <= 1:
            raise ValueError(
                f"Cannot remove CIDR '{cidr.cidr}': site must have at least one CIDR range"
            )

        # Capture values for logging before the CIDR is deleted/detached
        cidr_value = cidr.cidr
        site_name = site.name

        self.db.delete(cidr)
        site.updated_at = datetime.utcnow()

        self._commit_or_raise(f"Cannot remove CIDR '{cidr_value}'")

        logger.info("Removed CIDR '%s' from site %s ('%s')", cidr_value, site_id, site_name)

    def add_ip_override(self, cidr_id: int, ip_address: str,
                        expected_ping: Optional[bool] = None,
                        expected_tcp_ports: Optional[List[int]] = None,
                        expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]:
        """
        Add an IP-level expectation override within a CIDR.

        Args:
            cidr_id: CIDR ID
            ip_address: IP address to override
            expected_ping: Override ping expectation
            expected_tcp_ports: Override TCP ports expectation
            expected_udp_ports: Override UDP ports expectation

        Returns:
            Dictionary with IP override data

        Raises:
            ValueError: If CIDR not found, IP is invalid, not in CIDR range,
                already overridden, or the database rejects the insert
        """
        cidr = self.db.query(SiteCIDR).filter(SiteCIDR.id == cidr_id).first()
        if not cidr:
            raise ValueError(f"CIDR with id {cidr_id} not found")

        # Validate IP format
        try:
            ip_obj = ipaddress.ip_address(ip_address)
        except ValueError as e:
            raise ValueError(f"Invalid IP address '{ip_address}': {str(e)}") from e

        # Validate IP is within CIDR range (a v4/v6 mismatch is simply "not in")
        network = ipaddress.ip_network(cidr.cidr, strict=False)
        if ip_obj not in network:
            raise ValueError(f"IP address '{ip_address}' is not within CIDR '{cidr.cidr}'")

        # Check for duplicate (exact string match)
        existing = (
            self.db.query(SiteIP)
            .filter(SiteIP.site_cidr_id == cidr_id, SiteIP.ip_address == ip_address)
            .first()
        )
        if existing:
            raise ValueError(f"IP override for '{ip_address}' already exists in this CIDR")

        # Create IP override; port lists are stored JSON-encoded
        ip_override = SiteIP(
            site_cidr_id=cidr_id,
            ip_address=ip_address,
            expected_ping=expected_ping,
            expected_tcp_ports=json.dumps(expected_tcp_ports or []),
            expected_udp_ports=json.dumps(expected_udp_ports or []),
            created_at=datetime.utcnow()
        )

        self.db.add(ip_override)
        self._commit_or_raise(
            f"Could not add IP override '{ip_address}' to CIDR {cidr_id}"
        )
        self.db.refresh(ip_override)

        logger.info("Added IP override '%s' to CIDR %s ('%s')", ip_address, cidr_id, cidr.cidr)

        return self._ip_override_to_dict(ip_override)

    def remove_ip_override(self, cidr_id: int, ip_id: int) -> None:
        """
        Remove an IP-level override.

        Args:
            cidr_id: CIDR ID
            ip_id: IP override ID to remove

        Raises:
            ValueError: If IP override not found or the database rejects the
                deletion
        """
        ip_override = (
            self.db.query(SiteIP)
            .filter(SiteIP.id == ip_id, SiteIP.site_cidr_id == cidr_id)
            .first()
        )
        if not ip_override:
            raise ValueError(f"IP override with id {ip_id} not found for CIDR {cidr_id}")

        # Capture the address before the object is deleted/detached
        ip_address = ip_override.ip_address

        self.db.delete(ip_override)
        self._commit_or_raise(f"Cannot remove IP override '{ip_address}'")

        logger.info("Removed IP override '%s' from CIDR %s", ip_address, cidr_id)

    def get_scan_usage(self, site_id: int) -> List[Dict[str, Any]]:
        """
        Get list of scans that use this site.

        Args:
            site_id: Site ID

        Returns:
            List of scan dictionaries with 'id', 'title', 'timestamp'
            (ISO string or None) and 'status'
        """
        associations = (
            self.db.query(ScanSiteAssociation)
            .options(joinedload(ScanSiteAssociation.scan))
            .filter(ScanSiteAssociation.site_id == site_id)
            .all()
        )

        return [
            {
                'id': assoc.scan.id,
                'title': assoc.scan.title,
                'timestamp': self._iso(assoc.scan.timestamp),
                'status': assoc.scan.status
            }
            for assoc in associations
        ]

    # Private helper methods

    def _commit_or_raise(self, message: str) -> None:
        """
        Commit the session, translating integrity violations into ValueError.

        On ``IntegrityError`` the session is rolled back (leaving it usable)
        and ``ValueError("<message>: <driver error>")`` is raised, chained to
        the original exception.
        """
        try:
            self.db.commit()
        except IntegrityError as exc:
            self.db.rollback()
            raise ValueError(f"{message}: {exc.orig}") from exc

    def _detailed_site_query(self):
        """Build a Site query that eager-loads CIDRs and each CIDR's IP overrides."""
        return self.db.query(Site).options(
            joinedload(Site.cidrs).joinedload(SiteCIDR.ips)
        )

    def _add_cidr_to_site(self, site: Site, cidr_data: Dict[str, Any]) -> SiteCIDR:
        """
        Validate one CIDR definition and stage it on the session for ``site``.

        Does not flush or commit; the caller owns the transaction.

        Raises:
            ValueError: If the 'cidr' field is missing/empty or not valid CIDR notation
        """
        cidr = cidr_data.get('cidr')
        if not cidr:
            raise ValueError("CIDR 'cidr' field is required")

        # Validate CIDR format
        try:
            ipaddress.ip_network(cidr, strict=False)
        except ValueError as e:
            raise ValueError(f"Invalid CIDR notation '{cidr}': {str(e)}") from e

        cidr_obj = SiteCIDR(
            site_id=site.id,
            cidr=cidr,
            expected_ping=cidr_data.get('expected_ping'),
            # `or []` (not a .get default) so an explicit None is stored as
            # "[]" rather than "null", matching add_cidr()
            expected_tcp_ports=json.dumps(cidr_data.get('expected_tcp_ports') or []),
            expected_udp_ports=json.dumps(cidr_data.get('expected_udp_ports') or []),
            created_at=datetime.utcnow()
        )

        self.db.add(cidr_obj)
        return cidr_obj

    @staticmethod
    def _iso(value: Optional[datetime]) -> Optional[str]:
        """Return ``value.isoformat()``, or None when ``value`` is falsy."""
        return value.isoformat() if value else None

    @staticmethod
    def _load_ports(raw: Optional[str]) -> List[int]:
        """Decode a JSON-encoded port list column; empty, NULL or JSON ``null`` yield []."""
        if not raw:
            return []
        return json.loads(raw) or []

    def _site_to_dict(self, site: Site) -> Dict[str, Any]:
        """Convert Site model to dictionary."""
        return {
            'id': site.id,
            'name': site.name,
            'description': site.description,
            'created_at': self._iso(site.created_at),
            'updated_at': self._iso(site.updated_at),
            'cidrs': [self._cidr_to_dict(cidr) for cidr in site.cidrs] if hasattr(site, 'cidrs') else []
        }

    def _cidr_to_dict(self, cidr: SiteCIDR) -> Dict[str, Any]:
        """Convert SiteCIDR model to dictionary."""
        return {
            'id': cidr.id,
            'site_id': cidr.site_id,
            'cidr': cidr.cidr,
            'expected_ping': cidr.expected_ping,
            'expected_tcp_ports': self._load_ports(cidr.expected_tcp_ports),
            'expected_udp_ports': self._load_ports(cidr.expected_udp_ports),
            'created_at': self._iso(cidr.created_at),
            'ip_overrides': [self._ip_override_to_dict(ip) for ip in cidr.ips] if hasattr(cidr, 'ips') else []
        }

    def _ip_override_to_dict(self, ip: SiteIP) -> Dict[str, Any]:
        """Convert SiteIP model to dictionary."""
        return {
            'id': ip.id,
            'site_cidr_id': ip.site_cidr_id,
            'ip_address': ip.ip_address,
            'expected_ping': ip.expected_ping,
            'expected_tcp_ports': self._load_ports(ip.expected_tcp_ports),
            'expected_udp_ports': self._load_ports(ip.expected_udp_ports),
            'created_at': self._iso(ip.created_at)
        }