Files
SneakyScan/app/web/services/site_service.py

532 lines
17 KiB
Python

"""
Site service for managing reusable site definitions.
This service handles the business logic for creating, updating, and managing
sites with their associated CIDR ranges and IP-level overrides.
"""
import ipaddress
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from sqlalchemy import func
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.exc import IntegrityError
from web.models import (
Site, SiteCIDR, SiteIP, ScanSiteAssociation
)
from web.utils.pagination import paginate, PaginatedResult
# Module-level logger, named after this module so records identify their source.
logger = logging.getLogger(__name__)
class SiteService:
    """
    Service for managing reusable site definitions.

    Handles site lifecycle: creation, updates, deletion (with safety checks),
    CIDR management, and IP-level overrides.

    Every public mutator commits its own changes. Constraint violations raised
    at commit time (e.g. a concurrent insert of the same site name slipping in
    between the pre-insert SELECT and the commit) are rolled back and re-raised
    as ValueError, matching the errors raised by the up-front validation.
    """

    def __init__(self, db_session: Session):
        """
        Initialize site service.

        Args:
            db_session: SQLAlchemy database session
        """
        self.db = db_session

    def create_site(self, name: str, description: Optional[str] = None,
                    cidrs: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
        """
        Create a new site with optional CIDR ranges.

        All CIDR definitions are validated before anything is written, so an
        invalid entry never leaves a partially-built site pending in the session.

        Args:
            name: Unique site name
            description: Optional site description
            cidrs: List of CIDR definitions with format:
                [{"cidr": "10.0.0.0/24", "expected_ping": true,
                  "expected_tcp_ports": [22, 80], "expected_udp_ports": [53]}]
                ``None`` creates a site with no CIDRs; an empty list is rejected.

        Returns:
            Dictionary with created site data

        Raises:
            ValueError: If the site name already exists, ``cidrs`` is an empty
                list, a CIDR entry is missing or invalid, or the same CIDR
                appears more than once in ``cidrs``.
        """
        # Validate site name is unique
        existing = self.db.query(Site).filter(Site.name == name).first()
        if existing:
            raise ValueError(f"Site with name '{name}' already exists")

        # Validate we have at least one CIDR if provided
        if cidrs is not None and len(cidrs) == 0:
            raise ValueError("Site must have at least one CIDR range")

        cidr_list = cidrs or []

        # Validate every CIDR before touching the session. Previously an invalid
        # entry raised only after the site had been added and flushed.
        seen = set()
        for cidr_data in cidr_list:
            cidr = self._validate_cidr_data(cidr_data)
            if cidr in seen:
                raise ValueError(f"Duplicate CIDR '{cidr}' in site definition")
            seen.add(cidr)

        # A single timestamp so created_at and updated_at are identical.
        now = datetime.utcnow()
        site = Site(
            name=name,
            description=description,
            created_at=now,
            updated_at=now
        )
        try:
            self.db.add(site)
            self.db.flush()  # Get site.id without committing
            for cidr_data in cidr_list:
                self._add_cidr_to_site(site, cidr_data)
            self.db.commit()
        except IntegrityError as exc:
            # Presumably a concurrent request created the same name after the
            # uniqueness check above.
            self.db.rollback()
            raise ValueError(f"Site with name '{name}' already exists") from exc
        except Exception:
            # Never leave a half-created site pending in the session.
            self.db.rollback()
            raise

        self.db.refresh(site)
        logger.info(f"Created site '{name}' (id={site.id}) with {len(cidr_list)} CIDR(s)")
        return self._site_to_dict(site)

    def update_site(self, site_id: int, name: Optional[str] = None,
                    description: Optional[str] = None) -> Dict[str, Any]:
        """
        Update site metadata (name and/or description).

        Args:
            site_id: Site ID to update
            name: New site name (must be unique); ``None`` leaves it unchanged
            description: New description; ``None`` leaves it unchanged

        Returns:
            Dictionary with updated site data

        Raises:
            ValueError: If site not found or name already exists
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        # Update name if provided
        if name is not None and name != site.name:
            # Check uniqueness
            existing = self.db.query(Site).filter(
                Site.name == name,
                Site.id != site_id
            ).first()
            if existing:
                raise ValueError(f"Site with name '{name}' already exists")
            site.name = name

        # Update description if provided
        if description is not None:
            site.description = description

        site.updated_at = datetime.utcnow()
        new_name = site.name
        self._commit_or_conflict(f"Site with name '{new_name}' already exists")
        self.db.refresh(site)

        logger.info(f"Updated site {site_id} ('{site.name}')")
        return self._site_to_dict(site)

    def delete_site(self, site_id: int) -> None:
        """
        Delete a site.

        Prevents deletion if the site is used in any scan (per user requirement).

        Args:
            site_id: Site ID to delete

        Raises:
            ValueError: If site not found or is used in scans
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        # Check if site is used in any scans
        usage_count = (
            self.db.query(func.count(ScanSiteAssociation.id))
            .filter(ScanSiteAssociation.site_id == site_id)
            .scalar()
        )
        if usage_count > 0:
            raise ValueError(
                f"Cannot delete site '{site.name}': it is used in {usage_count} scan(s). "
                f"Sites that have been used in scans cannot be deleted."
            )

        # Capture the name before the row is deleted and the session committed.
        site_name = site.name
        self.db.delete(site)
        self.db.commit()

        logger.info(f"Deleted site {site_id} ('{site_name}')")

    def get_site(self, site_id: int) -> Optional[Dict[str, Any]]:
        """
        Get site details with all CIDRs and IP overrides.

        Args:
            site_id: Site ID to retrieve

        Returns:
            Dictionary with site data, or None if not found
        """
        site = (
            self.db.query(Site)
            .options(
                joinedload(Site.cidrs).joinedload(SiteCIDR.ips)
            )
            .filter(Site.id == site_id)
            .first()
        )
        if not site:
            return None
        return self._site_to_dict(site)

    def get_site_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """
        Get site details by name.

        Args:
            name: Site name to retrieve

        Returns:
            Dictionary with site data, or None if not found
        """
        site = (
            self.db.query(Site)
            .options(
                joinedload(Site.cidrs).joinedload(SiteCIDR.ips)
            )
            .filter(Site.name == name)
            .first()
        )
        if not site:
            return None
        return self._site_to_dict(site)

    def list_sites(self, page: int = 1, per_page: int = 20) -> PaginatedResult:
        """
        List all sites with pagination.

        Args:
            page: Page number (1-indexed)
            per_page: Number of items per page

        Returns:
            PaginatedResult with site data
        """
        # _site_to_dict serializes each CIDR's IP overrides too, so eager-load
        # them as well to avoid one lazy query per CIDR.
        query = (
            self.db.query(Site)
            .options(joinedload(Site.cidrs).joinedload(SiteCIDR.ips))
            .order_by(Site.name)
        )
        return paginate(query, page, per_page, self._site_to_dict)

    def list_all_sites(self) -> List[Dict[str, Any]]:
        """
        List all sites without pagination (for dropdowns, etc.).

        Returns:
            List of site dictionaries
        """
        # Eager-load IP overrides as well; see list_sites.
        sites = (
            self.db.query(Site)
            .options(joinedload(Site.cidrs).joinedload(SiteCIDR.ips))
            .order_by(Site.name)
            .all()
        )
        return [self._site_to_dict(site) for site in sites]

    def add_cidr(self, site_id: int, cidr: str, expected_ping: Optional[bool] = None,
                 expected_tcp_ports: Optional[List[int]] = None,
                 expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]:
        """
        Add a CIDR range to a site.

        Args:
            site_id: Site ID
            cidr: CIDR notation (e.g., "10.0.0.0/24")
            expected_ping: Expected ping response for IPs in this CIDR
            expected_tcp_ports: List of expected TCP ports
            expected_udp_ports: List of expected UDP ports

        Returns:
            Dictionary with CIDR data

        Raises:
            ValueError: If site not found, CIDR is invalid, or already exists
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        # Validate CIDR format
        self._parse_network(cidr)

        # Check for duplicate CIDR (exact string match)
        existing = (
            self.db.query(SiteCIDR)
            .filter(SiteCIDR.site_id == site_id, SiteCIDR.cidr == cidr)
            .first()
        )
        if existing:
            raise ValueError(f"CIDR '{cidr}' already exists for this site")

        # Create CIDR
        cidr_obj = SiteCIDR(
            site_id=site_id,
            cidr=cidr,
            expected_ping=expected_ping,
            expected_tcp_ports=json.dumps(expected_tcp_ports or []),
            expected_udp_ports=json.dumps(expected_udp_ports or []),
            created_at=datetime.utcnow()
        )
        self.db.add(cidr_obj)
        site.updated_at = datetime.utcnow()
        site_name = site.name
        self._commit_or_conflict(f"CIDR '{cidr}' already exists for this site")
        self.db.refresh(cidr_obj)

        logger.info(f"Added CIDR '{cidr}' to site {site_id} ('{site_name}')")
        return self._cidr_to_dict(cidr_obj)

    def remove_cidr(self, site_id: int, cidr_id: int) -> None:
        """
        Remove a CIDR range from a site.

        Prevents removal if it's the last CIDR (sites must have at least one CIDR).

        Args:
            site_id: Site ID
            cidr_id: CIDR ID to remove

        Raises:
            ValueError: If site or CIDR not found, or it's the last CIDR
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        cidr = (
            self.db.query(SiteCIDR)
            .filter(SiteCIDR.id == cidr_id, SiteCIDR.site_id == site_id)
            .first()
        )
        if not cidr:
            raise ValueError(f"CIDR with id {cidr_id} not found for site {site_id}")

        # Check if this is the last CIDR
        cidr_count = (
            self.db.query(func.count(SiteCIDR.id))
            .filter(SiteCIDR.site_id == site_id)
            .scalar()
        )
        if cidr_count <= 1:
            raise ValueError(
                f"Cannot remove CIDR '{cidr.cidr}': site must have at least one CIDR range"
            )

        # Capture values for logging before delete/commit.
        cidr_value = cidr.cidr
        site_name = site.name
        self.db.delete(cidr)
        site.updated_at = datetime.utcnow()
        self.db.commit()

        logger.info(f"Removed CIDR '{cidr_value}' from site {site_id} ('{site_name}')")

    def add_ip_override(self, cidr_id: int, ip_address: str,
                        expected_ping: Optional[bool] = None,
                        expected_tcp_ports: Optional[List[int]] = None,
                        expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]:
        """
        Add an IP-level expectation override within a CIDR.

        Args:
            cidr_id: CIDR ID
            ip_address: IP address to override
            expected_ping: Override ping expectation
            expected_tcp_ports: Override TCP ports expectation
            expected_udp_ports: Override UDP ports expectation

        Returns:
            Dictionary with IP override data

        Raises:
            ValueError: If CIDR not found, IP is invalid, not in CIDR range,
                or an override for this IP already exists
        """
        cidr = self.db.query(SiteCIDR).filter(SiteCIDR.id == cidr_id).first()
        if not cidr:
            raise ValueError(f"CIDR with id {cidr_id} not found")

        # Validate IP format
        try:
            ip_obj = ipaddress.ip_address(ip_address)
        except ValueError as e:
            raise ValueError(f"Invalid IP address '{ip_address}': {str(e)}") from e

        # Validate IP is within CIDR range (a v4/v6 mismatch is simply "not in")
        network = ipaddress.ip_network(cidr.cidr, strict=False)
        if ip_obj not in network:
            raise ValueError(f"IP address '{ip_address}' is not within CIDR '{cidr.cidr}'")

        # Check for duplicate
        existing = (
            self.db.query(SiteIP)
            .filter(SiteIP.site_cidr_id == cidr_id, SiteIP.ip_address == ip_address)
            .first()
        )
        if existing:
            raise ValueError(f"IP override for '{ip_address}' already exists in this CIDR")

        # Create IP override
        ip_override = SiteIP(
            site_cidr_id=cidr_id,
            ip_address=ip_address,
            expected_ping=expected_ping,
            expected_tcp_ports=json.dumps(expected_tcp_ports or []),
            expected_udp_ports=json.dumps(expected_udp_ports or []),
            created_at=datetime.utcnow()
        )
        self.db.add(ip_override)
        cidr_value = cidr.cidr
        self._commit_or_conflict(
            f"IP override for '{ip_address}' already exists in this CIDR"
        )
        self.db.refresh(ip_override)

        logger.info(f"Added IP override '{ip_address}' to CIDR {cidr_id} ('{cidr_value}')")
        return self._ip_override_to_dict(ip_override)

    def remove_ip_override(self, cidr_id: int, ip_id: int) -> None:
        """
        Remove an IP-level override.

        Args:
            cidr_id: CIDR ID
            ip_id: IP override ID to remove

        Raises:
            ValueError: If IP override not found
        """
        ip_override = (
            self.db.query(SiteIP)
            .filter(SiteIP.id == ip_id, SiteIP.site_cidr_id == cidr_id)
            .first()
        )
        if not ip_override:
            raise ValueError(f"IP override with id {ip_id} not found for CIDR {cidr_id}")

        ip_address = ip_override.ip_address
        self.db.delete(ip_override)
        self.db.commit()

        logger.info(f"Removed IP override '{ip_address}' from CIDR {cidr_id}")

    def get_scan_usage(self, site_id: int) -> List[Dict[str, Any]]:
        """
        Get list of scans that use this site.

        Args:
            site_id: Site ID

        Returns:
            List of dicts with keys 'id', 'title', 'timestamp' (ISO string or
            None) and 'status', one per scan associated with the site
        """
        associations = (
            self.db.query(ScanSiteAssociation)
            .options(joinedload(ScanSiteAssociation.scan))
            .filter(ScanSiteAssociation.site_id == site_id)
            .all()
        )

        return [
            {
                'id': assoc.scan.id,
                'title': assoc.scan.title,
                'timestamp': assoc.scan.timestamp.isoformat() if assoc.scan.timestamp else None,
                'status': assoc.scan.status
            }
            for assoc in associations
        ]

    # Private helper methods

    @staticmethod
    def _parse_network(cidr: str) -> "ipaddress.IPv4Network | ipaddress.IPv6Network":
        """
        Parse CIDR notation (host bits allowed), raising a uniform ValueError.

        Raises:
            ValueError: If ``cidr`` is not valid IPv4/IPv6 network notation
        """
        try:
            return ipaddress.ip_network(cidr, strict=False)
        except ValueError as e:
            raise ValueError(f"Invalid CIDR notation '{cidr}': {str(e)}") from e

    def _validate_cidr_data(self, cidr_data: Dict[str, Any]) -> str:
        """
        Validate a CIDR creation payload and return its 'cidr' string.

        Raises:
            ValueError: If the 'cidr' field is missing/empty or not valid notation
        """
        cidr = cidr_data.get('cidr')
        if not cidr:
            raise ValueError("CIDR 'cidr' field is required")
        self._parse_network(cidr)
        return cidr

    def _add_cidr_to_site(self, site: Site, cidr_data: Dict[str, Any]) -> SiteCIDR:
        """
        Stage a SiteCIDR for ``site`` from a creation payload (does not commit).

        Missing or ``None`` port lists are stored as ``[]``, the same encoding
        add_cidr uses, so they never round-trip as JSON ``null``.
        """
        cidr = self._validate_cidr_data(cidr_data)

        cidr_obj = SiteCIDR(
            site_id=site.id,
            cidr=cidr,
            expected_ping=cidr_data.get('expected_ping'),
            expected_tcp_ports=json.dumps(cidr_data.get('expected_tcp_ports') or []),
            expected_udp_ports=json.dumps(cidr_data.get('expected_udp_ports') or []),
            created_at=datetime.utcnow()
        )
        self.db.add(cidr_obj)
        return cidr_obj

    def _commit_or_conflict(self, conflict_message: str) -> None:
        """
        Commit the session, translating an IntegrityError into ValueError.

        Uniqueness is checked with a SELECT before inserting, but another writer
        can commit the same value between that check and this commit. Rolling
        back keeps the session usable; raising ValueError keeps the error type
        the same as the pre-check path. NOTE(review): assumes the only
        constraint a commit here can realistically violate is that uniqueness.
        """
        try:
            self.db.commit()
        except IntegrityError as exc:
            self.db.rollback()
            raise ValueError(conflict_message) from exc

    @staticmethod
    def _load_ports(raw: Optional[str]) -> List[int]:
        """
        Decode a JSON-encoded port-list column.

        Empty values and a stored JSON ``null`` both decode to ``[]`` so the
        serialized output always contains a list.
        """
        if not raw:
            return []
        return json.loads(raw) or []

    def _site_to_dict(self, site: Site) -> Dict[str, Any]:
        """Convert Site model to dictionary (including CIDRs and their overrides)."""
        return {
            'id': site.id,
            'name': site.name,
            'description': site.description,
            'created_at': site.created_at.isoformat() if site.created_at else None,
            'updated_at': site.updated_at.isoformat() if site.updated_at else None,
            'cidrs': [self._cidr_to_dict(cidr) for cidr in site.cidrs] if hasattr(site, 'cidrs') else []
        }

    def _cidr_to_dict(self, cidr: SiteCIDR) -> Dict[str, Any]:
        """Convert SiteCIDR model to dictionary (including its IP overrides)."""
        return {
            'id': cidr.id,
            'site_id': cidr.site_id,
            'cidr': cidr.cidr,
            'expected_ping': cidr.expected_ping,
            'expected_tcp_ports': self._load_ports(cidr.expected_tcp_ports),
            'expected_udp_ports': self._load_ports(cidr.expected_udp_ports),
            'created_at': cidr.created_at.isoformat() if cidr.created_at else None,
            'ip_overrides': [self._ip_override_to_dict(ip) for ip in cidr.ips] if hasattr(cidr, 'ips') else []
        }

    def _ip_override_to_dict(self, ip: SiteIP) -> Dict[str, Any]:
        """Convert SiteIP model to dictionary."""
        return {
            'id': ip.id,
            'site_cidr_id': ip.site_cidr_id,
            'ip_address': ip.ip_address,
            'expected_ping': ip.expected_ping,
            'expected_tcp_ports': self._load_ports(ip.expected_tcp_ports),
            'expected_udp_ports': self._load_ports(ip.expected_udp_ports),
            'created_at': ip.created_at.isoformat() if ip.created_at else None
        }