"""
Site service for managing reusable site definitions.

This service handles the business logic for creating, updating, and managing
sites with their associated CIDR ranges and IP-level overrides.
"""

import ipaddress
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional

from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, joinedload

from web.models import (
    Site, SiteCIDR, SiteIP, ScanSiteAssociation
)
from web.utils.pagination import paginate, PaginatedResult

logger = logging.getLogger(__name__)


class SiteService:
    """
    Service for managing reusable site definitions.

    Handles site lifecycle: creation, updates, deletion (with safety checks),
    CIDR management, and IP-level overrides.

    Every mutating method commits the session it was given. If a commit (or
    the flush in ``create_site``) raises ``IntegrityError`` -- e.g. a
    unique-constraint violation, if the schema defines one, when a concurrent
    request wins the race between this service's existence check and its
    INSERT -- the session is rolled back and a ``ValueError`` is raised, so
    callers only have to handle one exception type.
    """

    def __init__(self, db_session: Session):
        """
        Initialize site service.

        Args:
            db_session: SQLAlchemy database session used for all queries and
                committed by the mutating methods
        """
        self.db = db_session

    def create_site(self, name: str, description: Optional[str] = None,
                    cidrs: Optional[List[Dict[str, Any]]] = None) -> Dict[str, Any]:
        """
        Create a new site with optional CIDR ranges.

        Args:
            name: Unique site name
            description: Optional site description
            cidrs: List of CIDR definitions with format:
                [{"cidr": "10.0.0.0/24", "expected_ping": true,
                  "expected_tcp_ports": [22, 80], "expected_udp_ports": [53]}]
                ``None`` creates a site without CIDRs; an empty list is rejected.

        Returns:
            Dictionary with created site data

        Raises:
            ValueError: If site name already exists, a CIDR entry is missing,
                malformed or duplicated within ``cidrs``, or the insert
                violates a database constraint
        """
        # Validate site name is unique
        existing = self.db.query(Site).filter(Site.name == name).first()
        if existing:
            raise ValueError(f"Site with name '{name}' already exists")

        # Validate we have at least one CIDR if provided
        if cidrs is not None and not cidrs:
            raise ValueError("Site must have at least one CIDR range")

        cidr_list = cidrs or []

        # Validate every CIDR definition *before* adding anything to the
        # session, so a bad entry cannot leave a half-built, already-flushed
        # site pending in the caller's session.
        seen = set()
        for cidr_data in cidr_list:
            cidr = self._cidr_from_data(cidr_data)
            if cidr in seen:
                raise ValueError(f"Duplicate CIDR '{cidr}' in site definition")
            seen.add(cidr)

        now = datetime.utcnow()
        site = Site(
            name=name,
            description=description,
            created_at=now,
            updated_at=now,
        )

        try:
            self.db.add(site)
            self.db.flush()  # Get site.id without committing
            for cidr_data in cidr_list:
                self._add_cidr_to_site(site, cidr_data)
            self.db.commit()
        except IntegrityError as exc:
            # A failed flush/commit leaves the session unusable until rollback.
            self.db.rollback()
            raise ValueError(
                f"Could not create site '{name}': conflicts with existing data"
            ) from exc

        self.db.refresh(site)

        logger.info("Created site '%s' (id=%s) with %d CIDR(s)",
                    name, site.id, len(cidr_list))

        return self._site_to_dict(site)

    def update_site(self, site_id: int, name: Optional[str] = None,
                    description: Optional[str] = None) -> Dict[str, Any]:
        """
        Update site metadata (name and/or description).

        Args:
            site_id: Site ID to update
            name: New site name (must be unique); ``None`` leaves it unchanged
            description: New description; ``None`` leaves it unchanged, so a
                description cannot be cleared to ``None`` through this method

        Returns:
            Dictionary with updated site data

        Raises:
            ValueError: If site not found, name already exists, or the commit
                violates a database constraint
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        # Update name if provided
        if name is not None and name != site.name:
            existing = (
                self.db.query(Site)
                .filter(Site.name == name, Site.id != site_id)
                .first()
            )
            if existing:
                raise ValueError(f"Site with name '{name}' already exists")
            site.name = name

        # Update description if provided
        if description is not None:
            site.description = description

        site.updated_at = datetime.utcnow()

        self._commit(f"Could not update site {site_id}: conflicts with existing data")
        self.db.refresh(site)

        logger.info("Updated site %s ('%s')", site_id, site.name)

        return self._site_to_dict(site)

    def delete_site(self, site_id: int) -> None:
        """
        Delete a site.

        Prevents deletion if the site is used in any scan (per user requirement).

        Args:
            site_id: Site ID to delete

        Raises:
            ValueError: If site not found, is used in scans, or the delete
                violates a database constraint
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        # Check if site is used in any scans
        usage_count = (
            self.db.query(func.count(ScanSiteAssociation.id))
            .filter(ScanSiteAssociation.site_id == site_id)
            .scalar()
        )

        if usage_count > 0:
            raise ValueError(
                f"Cannot delete site '{site.name}': it is used in {usage_count} scan(s). "
                "Sites that have been used in scans cannot be deleted."
            )

        # Capture before deleting so logging never touches the deleted instance.
        site_name = site.name

        # Safe to delete
        self.db.delete(site)
        self._commit(
            f"Cannot delete site '{site_name}': it is still referenced by other records"
        )

        logger.info("Deleted site %s ('%s')", site_id, site_name)

    def get_site(self, site_id: int) -> Optional[Dict[str, Any]]:
        """
        Get site details with all CIDRs and IP overrides.

        Args:
            site_id: Site ID to retrieve

        Returns:
            Dictionary with site data, or None if not found
        """
        site = self._detailed_site_query().filter(Site.id == site_id).first()
        return self._site_to_dict(site) if site else None

    def get_site_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """
        Get site details by name.

        Args:
            name: Site name to retrieve

        Returns:
            Dictionary with site data, or None if not found
        """
        site = self._detailed_site_query().filter(Site.name == name).first()
        return self._site_to_dict(site) if site else None

    def list_sites(self, page: int = 1, per_page: int = 20) -> PaginatedResult:
        """
        List all sites with pagination, ordered by name.

        Args:
            page: Page number (1-indexed)
            per_page: Number of items per page

        Returns:
            PaginatedResult with site data
        """
        # Eager-load IP overrides too: _site_to_dict serializes cidr.ips, which
        # would otherwise issue one lazy query per CIDR.
        query = self._detailed_site_query().order_by(Site.name)
        return paginate(query, page, per_page, self._site_to_dict)

    def list_all_sites(self) -> List[Dict[str, Any]]:
        """
        List all sites without pagination (for dropdowns, etc.), ordered by name.

        Returns:
            List of site dictionaries
        """
        sites = self._detailed_site_query().order_by(Site.name).all()
        return [self._site_to_dict(site) for site in sites]

    def add_cidr(self, site_id: int, cidr: str, expected_ping: Optional[bool] = None,
                 expected_tcp_ports: Optional[List[int]] = None,
                 expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]:
        """
        Add a CIDR range to a site.

        Args:
            site_id: Site ID
            cidr: CIDR notation (e.g., "10.0.0.0/24"); stored as given
            expected_ping: Expected ping response for IPs in this CIDR
            expected_tcp_ports: List of expected TCP ports (``None`` -> [])
            expected_udp_ports: List of expected UDP ports (``None`` -> [])

        Returns:
            Dictionary with CIDR data

        Raises:
            ValueError: If site not found, CIDR is invalid, already exists,
                or the insert violates a database constraint
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        self._validate_cidr(cidr)

        # Duplicate check compares the raw string, exactly as stored.
        existing = (
            self.db.query(SiteCIDR)
            .filter(SiteCIDR.site_id == site_id, SiteCIDR.cidr == cidr)
            .first()
        )
        if existing:
            raise ValueError(f"CIDR '{cidr}' already exists for this site")

        now = datetime.utcnow()
        cidr_obj = SiteCIDR(
            site_id=site_id,
            cidr=cidr,
            expected_ping=expected_ping,
            expected_tcp_ports=self._ports_to_json(expected_tcp_ports),
            expected_udp_ports=self._ports_to_json(expected_udp_ports),
            created_at=now,
        )

        self.db.add(cidr_obj)
        site.updated_at = now
        site_name = site.name  # read before commit expires the instance
        self._commit(
            f"Could not add CIDR '{cidr}' to site {site_id}: conflicts with existing data"
        )
        self.db.refresh(cidr_obj)

        logger.info("Added CIDR '%s' to site %s ('%s')", cidr, site_id, site_name)

        return self._cidr_to_dict(cidr_obj)

    def remove_cidr(self, site_id: int, cidr_id: int) -> None:
        """
        Remove a CIDR range from a site.

        Prevents removal if it's the last CIDR (sites must have at least one CIDR).

        Args:
            site_id: Site ID
            cidr_id: CIDR ID to remove

        Raises:
            ValueError: If site or CIDR not found, it's the last CIDR, or the
                delete violates a database constraint
        """
        site = self.db.query(Site).filter(Site.id == site_id).first()
        if not site:
            raise ValueError(f"Site with id {site_id} not found")

        cidr = (
            self.db.query(SiteCIDR)
            .filter(SiteCIDR.id == cidr_id, SiteCIDR.site_id == site_id)
            .first()
        )
        if not cidr:
            raise ValueError(f"CIDR with id {cidr_id} not found for site {site_id}")

        # Check if this is the last CIDR
        cidr_count = (
            self.db.query(func.count(SiteCIDR.id))
            .filter(SiteCIDR.site_id == site_id)
            .scalar()
        )
        if cidr_count <= 1:
            raise ValueError(
                f"Cannot remove CIDR '{cidr.cidr}': site must have at least one CIDR range"
            )

        # Capture for logging before the commit expires/detaches the instances.
        cidr_text = cidr.cidr
        site_name = site.name

        self.db.delete(cidr)
        site.updated_at = datetime.utcnow()
        self._commit(
            f"Cannot remove CIDR '{cidr_text}': it is still referenced by other records"
        )

        logger.info("Removed CIDR '%s' from site %s ('%s')", cidr_text, site_id, site_name)

    def add_ip_override(self, cidr_id: int, ip_address: str,
                        expected_ping: Optional[bool] = None,
                        expected_tcp_ports: Optional[List[int]] = None,
                        expected_udp_ports: Optional[List[int]] = None) -> Dict[str, Any]:
        """
        Add an IP-level expectation override within a CIDR.

        Args:
            cidr_id: CIDR ID
            ip_address: IP address to override; stored as given
            expected_ping: Override ping expectation
            expected_tcp_ports: Override TCP ports expectation (``None`` -> [])
            expected_udp_ports: Override UDP ports expectation (``None`` -> [])

        Returns:
            Dictionary with IP override data

        Raises:
            ValueError: If CIDR not found, IP is invalid, not in CIDR range,
                already overridden, or the insert violates a database constraint
        """
        cidr = self.db.query(SiteCIDR).filter(SiteCIDR.id == cidr_id).first()
        if not cidr:
            raise ValueError(f"CIDR with id {cidr_id} not found")

        # Validate IP format
        try:
            ip_obj = ipaddress.ip_address(ip_address)
        except ValueError as e:
            raise ValueError(f"Invalid IP address '{ip_address}': {str(e)}") from e

        # Validate IP is within CIDR range (a v4/v6 mismatch is simply "not in").
        network = ipaddress.ip_network(cidr.cidr, strict=False)
        if ip_obj not in network:
            raise ValueError(f"IP address '{ip_address}' is not within CIDR '{cidr.cidr}'")

        # Check for duplicate (raw string comparison, exactly as stored)
        existing = (
            self.db.query(SiteIP)
            .filter(SiteIP.site_cidr_id == cidr_id, SiteIP.ip_address == ip_address)
            .first()
        )
        if existing:
            raise ValueError(f"IP override for '{ip_address}' already exists in this CIDR")

        ip_override = SiteIP(
            site_cidr_id=cidr_id,
            ip_address=ip_address,
            expected_ping=expected_ping,
            expected_tcp_ports=self._ports_to_json(expected_tcp_ports),
            expected_udp_ports=self._ports_to_json(expected_udp_ports),
            created_at=datetime.utcnow(),
        )

        cidr_text = cidr.cidr  # read before commit expires the instance
        self.db.add(ip_override)
        self._commit(
            f"Could not add IP override '{ip_address}' to CIDR {cidr_id}: "
            "conflicts with existing data"
        )
        self.db.refresh(ip_override)

        logger.info("Added IP override '%s' to CIDR %s ('%s')", ip_address, cidr_id, cidr_text)

        return self._ip_override_to_dict(ip_override)

    def remove_ip_override(self, cidr_id: int, ip_id: int) -> None:
        """
        Remove an IP-level override.

        Args:
            cidr_id: CIDR ID
            ip_id: IP override ID to remove

        Raises:
            ValueError: If IP override not found or the delete violates a
                database constraint
        """
        ip_override = (
            self.db.query(SiteIP)
            .filter(SiteIP.id == ip_id, SiteIP.site_cidr_id == cidr_id)
            .first()
        )
        if not ip_override:
            raise ValueError(f"IP override with id {ip_id} not found for CIDR {cidr_id}")

        ip_address = ip_override.ip_address
        self.db.delete(ip_override)
        self._commit(
            f"Cannot remove IP override '{ip_address}': it is still referenced by other records"
        )

        logger.info("Removed IP override '%s' from CIDR %s", ip_address, cidr_id)

    def get_scan_usage(self, site_id: int) -> List[Dict[str, Any]]:
        """
        Get list of scans that use this site.

        Args:
            site_id: Site ID

        Returns:
            List of scan dictionaries with ``id``, ``title``, ``timestamp``
            (ISO string or None) and ``status``
        """
        associations = (
            self.db.query(ScanSiteAssociation)
            .options(joinedload(ScanSiteAssociation.scan))
            .filter(ScanSiteAssociation.site_id == site_id)
            .all()
        )

        return [
            {
                'id': assoc.scan.id,
                'title': assoc.scan.title,
                'timestamp': assoc.scan.timestamp.isoformat() if assoc.scan.timestamp else None,
                'status': assoc.scan.status,
            }
            for assoc in associations
        ]

    # Private helper methods

    def _commit(self, conflict_message: str) -> None:
        """Commit; on IntegrityError roll back and raise ValueError(conflict_message)."""
        try:
            self.db.commit()
        except IntegrityError as exc:
            # The session cannot be used again until the failed transaction
            # is rolled back.
            self.db.rollback()
            raise ValueError(conflict_message) from exc

    def _detailed_site_query(self):
        """Return a Site query with CIDRs and their IP overrides eagerly loaded."""
        return self.db.query(Site).options(
            joinedload(Site.cidrs).joinedload(SiteCIDR.ips)
        )

    @staticmethod
    def _validate_cidr(cidr: str) -> None:
        """Raise ValueError if ``cidr`` is not parseable network notation (host bits allowed)."""
        try:
            ipaddress.ip_network(cidr, strict=False)
        except ValueError as e:
            raise ValueError(f"Invalid CIDR notation '{cidr}': {str(e)}") from e

    @classmethod
    def _cidr_from_data(cls, cidr_data: Dict[str, Any]) -> str:
        """Extract and validate the 'cidr' field of a CIDR definition dict."""
        cidr = cidr_data.get('cidr')
        if not cidr:
            raise ValueError("CIDR 'cidr' field is required")
        cls._validate_cidr(cidr)
        return cidr

    @staticmethod
    def _ports_to_json(ports: Optional[List[int]]) -> str:
        """Serialize a port list for storage; ``None``/empty becomes ``"[]"``."""
        return json.dumps(ports or [])

    @staticmethod
    def _ports_from_json(raw: Optional[str]) -> List[int]:
        """Deserialize a stored port list; empty or JSON ``null`` becomes ``[]``."""
        if not raw:
            return []
        parsed = json.loads(raw)
        # Rows written before port lists were normalized may contain "null".
        return parsed if parsed is not None else []

    def _add_cidr_to_site(self, site: Site, cidr_data: Dict[str, Any]) -> SiteCIDR:
        """Helper to add a CIDR during site creation (site must already have an id)."""
        cidr = self._cidr_from_data(cidr_data)

        cidr_obj = SiteCIDR(
            site_id=site.id,
            cidr=cidr,
            expected_ping=cidr_data.get('expected_ping'),
            # ``or []`` semantics: an explicit None must not be stored as "null".
            expected_tcp_ports=self._ports_to_json(cidr_data.get('expected_tcp_ports')),
            expected_udp_ports=self._ports_to_json(cidr_data.get('expected_udp_ports')),
            created_at=datetime.utcnow(),
        )

        self.db.add(cidr_obj)
        return cidr_obj

    def _site_to_dict(self, site: Site) -> Dict[str, Any]:
        """Convert Site model to dictionary."""
        return {
            'id': site.id,
            'name': site.name,
            'description': site.description,
            'created_at': site.created_at.isoformat() if site.created_at else None,
            'updated_at': site.updated_at.isoformat() if site.updated_at else None,
            'cidrs': [self._cidr_to_dict(cidr) for cidr in site.cidrs] if hasattr(site, 'cidrs') else [],
        }

    def _cidr_to_dict(self, cidr: SiteCIDR) -> Dict[str, Any]:
        """Convert SiteCIDR model to dictionary."""
        return {
            'id': cidr.id,
            'site_id': cidr.site_id,
            'cidr': cidr.cidr,
            'expected_ping': cidr.expected_ping,
            'expected_tcp_ports': self._ports_from_json(cidr.expected_tcp_ports),
            'expected_udp_ports': self._ports_from_json(cidr.expected_udp_ports),
            'created_at': cidr.created_at.isoformat() if cidr.created_at else None,
            'ip_overrides': [self._ip_override_to_dict(ip) for ip in cidr.ips] if hasattr(cidr, 'ips') else [],
        }

    def _ip_override_to_dict(self, ip: SiteIP) -> Dict[str, Any]:
        """Convert SiteIP model to dictionary."""
        return {
            'id': ip.id,
            'site_cidr_id': ip.site_cidr_id,
            'ip_address': ip.ip_address,
            'expected_ping': ip.expected_ping,
            'expected_tcp_ports': self._ports_from_json(ip.expected_tcp_ports),
            'expected_udp_ports': self._ports_from_json(ip.expected_udp_ports),
            'created_at': ip.created_at.isoformat() if ip.created_at else None,
        }