feat: HTTPS auto-normalization; robust TLS intel UI; global rules state; clean logging; preload

- Add SSL/TLS intelligence pipeline:
  - crt.sh lookup with expired-certificate filtering and root-domain wildcard resolution
  - live TLS version/cipher probe with weak/legacy flags and probe notes
- UI: card + matrix rendering, raw JSON toggle, and host/wildcard cert lists (result payload sketched below)
- Front page: checkbox to optionally fetch certificate/CT data
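
For orientation, a sketch of the `ssl_tls` payload shape the UI consumes; field names are taken from the enrichment diff below, values are illustrative:

```python
# Illustrative shape of enrichment["ssl_tls"] when fetch_ssl_enabled=True.
ssl_tls_example = {
    "crtsh": {                                   # gather_crtsh_certs_for_target(...)
        "input": "https://portal.example.com",
        "hostname": "portal.example.com",
        "root_domain": "example.com",
        "is_root_domain": False,
        "crtsh": {
            "host_certs": [],                    # crt.sh rows for the exact host
            "wildcard_root_certs": [],           # crt.sh rows for *.example.com
        },
    },
    "probe": {                                   # TLSEnumerator.probe(...).to_dict()
        "hostname": "portal.example.com",
        "port": 443,
        "results_by_version": {
            "TLS1.2": {"supported": True,
                       "selected_cipher": "ECDHE-RSA-AES128-GCM-SHA256",
                       "handshake_seconds": 0.05, "error": None},
        },
        "weak_protocols": [],                    # e.g. ["TLS1.0", "TLS1.1"]
        "weak_ciphers": [],                      # e.g. ["RC4-SHA"]
        "errors": [],
    },
}
```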

- Introduce `URLNormalizer` with punycode support and typo repair
  - Auto-prepend `https://` for bare domains (e.g., `google.com`)
  - Optional quick HTTPS reachability + `http://` fallback
- Provide singleton via function-cached `@singleton_loader`:
  - `get_url_normalizer()` reads defaults from Settings when present (usage sketched below)
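
A minimal usage sketch of the normalizer (module path taken from the new file's diff header below; outputs follow the behavior described above):

```python
from app.utils.url_tools import get_url_normalizer

normalizer = get_url_normalizer()  # singleton: the first call's arguments win

print(normalizer.normalize_for_analysis("google.com"))
# -> "https://google.com" (scheme auto-prepended)

print(normalizer.normalize_for_analysis("https//bücher.de/path"))
# -> "https://xn--bcher-kva.de/path" (missing colon repaired, host punycoded)
```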

- Standardize function-rule return shape to `(bool, dict|None)` across
  `form_*` and `script_*` rules; include structured payloads (`note`, hosts, ext, etc.)
- Harden `FunctionRuleAdapter` (coercion contract sketched below):
  - Coerce legacy returns `(bool)`, `(bool, str)` → normalized outputs
  - Adapt non-dict inputs to facts (category-aware and via provided adapter)
  - Return `(True, dict)` on match, `(False, None)` on miss
  - Bind-time logging with file:line + function id for diagnostics
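
A minimal sketch of the coercion contract described above; `coerce_rule_return` is a hypothetical helper, not the adapter's actual internals:

```python
def coerce_rule_return(raw):
    """Normalize legacy rule returns to the (bool, dict | None) shape."""
    if isinstance(raw, tuple):
        matched = raw[0]
        payload = raw[1] if len(raw) > 1 else None
    else:                                   # bare bool from the oldest rules
        matched, payload = raw, None
    if not matched:
        return False, None                  # miss
    if payload is None:
        return True, {}
    if isinstance(payload, str):            # legacy (bool, str) shape
        return True, {"note": payload}
    return True, dict(payload)              # already structured
```
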
- `RuleEngine`:
  - Back rules with a private `self._rules` list; the `rules` property returns a copy
  - Idempotent `add_rule(replace=False)` with in-place replacement and regex (re)compile (see the sketch below)
  - Fix AttributeError from property assignment during `__init__`
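
A simplified sketch of the private-list pattern and the idempotent add (the real class also recompiles regexes and logs replacements):

```python
class RuleEngine:
    def __init__(self, rules=None):
        self._rules = list(rules or [])    # assign the backing field, not the property

    @property
    def rules(self):
        return list(self._rules)           # copy: callers cannot mutate engine state

    def add_rule(self, rule, replace=False):
        for index, existing in enumerate(self._rules):
            if existing.name == rule.name:
                if replace:
                    self._rules[index] = rule   # in-place replace keeps ordering
                return                          # idempotent: never a duplicate
        self._rules.append(rule)
```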

- Replace hidden singleton factory with explicit builder + global state:
  - `app/rules/factory.py::build_rules_engine()` builds and logs totals
  - `app/state.py` exposes `set_rules_engine()` / `get_rules_engine()` as the single source of truth (sketched below)
  - `app/wsgi.py` builds once at preload and publishes via `set_rules_engine()`
- Add lightweight debug hooks (`SS_DEBUG_RULES=1`) to trace engine id and rule counts
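
A sketch of the `app/state.py` contract (simplified; the real module also emits the `SS_DEBUG_RULES` traces):

```python
# app/state.py (sketch): process-global rules engine, published once at preload.
_rules_engine = None

def set_rules_engine(engine):
    global _rules_engine
    _rules_engine = engine

def get_rules_engine():
    if _rules_engine is None:
        raise RuntimeError("Rules engine not initialized; build it at preload")
    return _rules_engine
```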

- Unify logging wiring:
  - `wire_logging_once(app)` clears existing handlers and attaches a single handler chain (sketched below)
  - Create two named loggers: `sneakyscope.app` and `sneakyscope.engine`
  - Disable propagation to prevent duplicate log lines; include pid/logger name in the format
- Remove stray/duplicate handlers and import-time logging
- Optional dedup filter for bursty repeats (kept off by default)
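
A sketch of the single-chain wiring, assuming the logger names above (formatter and levels are illustrative):

```python
import logging
import sys

def wire_logging_once(app):
    """Sketch only: one handler chain, two named loggers, no propagation."""
    formatter = logging.Formatter("[%(process)d] %(name)s %(levelname)s: %(message)s")
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(formatter)
    for name in ("sneakyscope.app", "sneakyscope.engine"):
        log = logging.getLogger(name)
        log.handlers.clear()               # drop stray/duplicate handlers
        log.addHandler(handler)
        log.setLevel(logging.INFO)
        log.propagate = False              # no duplicate lines via the root logger
    app.logger.handlers.clear()
    app.logger.addHandler(handler)
```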

- Gunicorn: enable `--preload` in entrypoint to avoid thread races and double registration
- Documented foreground vs background log “double consumer” caveat (attach vs `compose logs`)

- Jinja: replace `{% return %}` with structured `if/elif/else` branches
- Add a toggle button to show raw JSON for the TLS/CT section

- Consumers should import the rules engine via:
  - `from app.state import get_rules_engine`
- Use `build_rules_engine()` **only** during preload/init to construct the instance,
  then publish with `set_rules_engine()`. Do not call the old singleton factories; a consumer snippet follows.
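
For example, on the consumer side (illustrative; assumes the engine keeps the `run_all(text, category=...)` API shown in the rules-engine diff below):

```python
from app.state import get_rules_engine

def check_scripts(html: str):
    engine = get_rules_engine()            # never build the engine here
    return engine.run_all(html, category="script")
```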

- New/changed modules (high level):
  - `app/utils/url_tools.py` (+) — URLNormalizer + `get_url_normalizer()`
  - `app/rules/function_rules.py` (±) — normalized payload returns
  - `engine/function_rule_adapter.py` (±) — coercion, fact adaptation, bind logs
  - `app/utils/rules_engine.py` (±) — `_rules`, idempotent `add_rule`, fixes
  - `app/rules/factory.py` (±) — pure builder; totals logged post-registration
  - `app/state.py` (+) — process-global rules engine
  - `app/logging_setup.py` (±) — single chain, two named loggers
  - `app/wsgi.py` (±) — preload build + `set_rules_engine()`
  - `entrypoint.sh` (±) — add `--preload`
  - templates (±) — TLS card, raw toggle; front-page checkbox

Closes: flaky rule-type warnings, duplicate logs, and multi-worker race on rules init.
Commit 693f7d67b9 (parent f639ad0934), 2025-08-21 22:05:16 -05:00
22 changed files with 1476 additions and 256 deletions


@@ -35,10 +35,12 @@ from playwright.async_api import async_playwright, TimeoutError as PWTimeoutError
from app.utils.io_helpers import safe_write
from app.utils.enrichment import enrich_url
from app.utils.settings import get_settings
from app.logging_setup import get_app_logger
# Load settings once for constants / defaults
settings = get_settings()
logger = get_app_logger()
class Browser:
"""
@@ -280,7 +282,7 @@ class Browser:
except Exception as rule_exc:
# Be defensive—bad rule shouldn't break the form pass
try:
self.logger.debug("Form rule error", extra={"rule": getattr(r, "name", "?"), "error": str(rule_exc)})
logger.debug("Form rule error", extra={"rule": getattr(r, "name", "?"), "error": str(rule_exc)})
except Exception:
pass
continue
@@ -298,7 +300,7 @@ class Browser:
except Exception as exc:
# Keep analysis resilient
try:
self.logger.error("Form analysis error", extra={"error": str(exc)})
logger.error("Form analysis error", extra={"error": str(exc)})
except Exception:
pass
results.append({
@@ -390,7 +392,7 @@ class Browser:
# -----------------------------------------------------------------------
# Fetcher / Orchestrator
# -----------------------------------------------------------------------
async def fetch_page_artifacts(self, url: str) -> Dict[str, Any]:
async def fetch_page_artifacts(self, url: str, fetch_ssl_enabled: bool = False) -> Dict[str, Any]:
"""
Fetch page artifacts and save them in a UUID-based directory for this Browser's storage_dir.
@@ -476,7 +478,7 @@ class Browser:
suspicious_scripts = self.analyze_scripts(html_content, base_url=final_url)
# Enrichment
enrichment = enrich_url(url)
enrichment = enrich_url(url, fetch_ssl_enabled)
# Global PASS/FAIL table per category (entire document)
rule_checks_overview = self.build_rule_checks_overview(html_content)
@@ -505,7 +507,7 @@ class Browser:
safe_write(results_path, json.dumps(result, indent=2, ensure_ascii=False))
try:
current_app.logger.info(f"[browser] Saved results.json for run {run_uuid}")
logger.info(f"Saved results.json for run {run_uuid}")
except Exception:
pass


@@ -1,19 +1,25 @@
import logging
from pathlib import Path
from urllib.parse import urlparse
import requests
import yaml
import json
import whois
from datetime import datetime
from ipaddress import ip_address
import socket
# Optional: high-accuracy root-domain detection if available (tldextract is in the requirements, but this is still useful)
try:
import tldextract
_HAS_TLDEXTRACT = True
except Exception:
_HAS_TLDEXTRACT = False
# Local imports
from app.utils.cache_db import get_cache
from app.utils.settings import get_settings
from app.utils.tls_probe import TLSEnumerator
# Configure logging
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
# Configure logger
from app.logging_setup import get_app_logger
# Init cache
cache = get_cache("/data/cache.db")
@@ -25,32 +31,244 @@ days = 24 * 60
GEOIP_DEFAULT_TTL = settings.cache.geoip_cache_days * days
WHOIS_DEFAULT_TTL = settings.cache.whois_cache_days * days
def enrich_url(url: str) -> dict:
"""Perform WHOIS, GeoIP, and BEC word enrichment."""
result = {}
logger = get_app_logger()
def parse_target_to_host(target):
"""
Convert a user-supplied string (URL or domain) into a hostname.
Returns:
str or None
"""
if target is None:
return None
value = str(target).strip()
if value == "":
return None
# urlparse needs a scheme to treat the first token as netloc
parsed = urlparse(value if "://" in value else f"http://{value}")
# If the input was something like "localhost:8080/path", netloc includes the port
host = parsed.hostname
if host is None:
return None
# Lowercase for consistency
host = host.strip().lower()
if host == "":
return None
return host
def get_root_domain(hostname):
"""
Determine the registrable/root domain from a hostname.
Prefers tldextract if available; otherwise falls back to a heuristic.
Examples:
sub.a.example.com -> example.com
portal.gov.uk -> gov.uk (but with PSL, you'd get portal.gov.uk's registrable, which is gov.uk)
api.example.co.uk -> example.co.uk (PSL needed for correctness)
Returns:
str (best-effort registrable domain)
"""
if hostname is None:
return None
if _HAS_TLDEXTRACT:
# tldextract returns subdomain, domain, suffix separately using PSL rules
# e.g., sub= "api", domain="example", suffix="co.uk"
parts = tldextract.extract(hostname)
# If suffix is empty (e.g., localhost), fall back
if parts.suffix:
return f"{parts.domain}.{parts.suffix}".lower()
else:
return hostname.lower()
# Fallback heuristic: last two labels (not perfect for multi-part TLDs, but safe)
# We avoid list comprehensions per your preference for explicit code
labels = hostname.split(".")
labels = [lbl for lbl in labels if lbl] # allow simple cleanup without logic change
if len(labels) >= 2:
last = labels[-1]
second_last = labels[-2]
candidate = f"{second_last}.{last}".lower()
return candidate
return hostname.lower()
def is_root_domain(hostname):
"""
Is the provided hostname the same as its registrable/root domain?
"""
if hostname is None:
return False
root = get_root_domain(hostname)
if root is None:
return False
return hostname.lower() == root.lower()
def search_certs(domain, wildcard=True, expired=True, deduplicate=True):
"""
Search crt.sh for the given domain.
domain -- Domain to search for
wildcard -- Whether or not to prepend a wildcard to the domain
(default: True)
expired -- Whether or not to include expired certificates
(default: True)
Return a list of objects, like so:
{
"issuer_ca_id": 16418,
"issuer_name": "C=US, O=Let's Encrypt, CN=Let's Encrypt Authority X3",
"name_value": "hatch.uber.com",
"min_cert_id": 325717795,
"min_entry_timestamp": "2018-02-08T16:47:39.089",
"not_before": "2018-02-08T15:47:39"
}
"""
base_url = "https://crt.sh/?q={}&output=json"
if not expired:
base_url = base_url + "&exclude=expired"
if deduplicate:
base_url = base_url + "&deduplicate=Y"
if wildcard and "%" not in domain:
domain = "%.{}".format(domain)
url = base_url.format(domain)
ua = 'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:40.0) Gecko/20100101 Firefox/40.1'
req = requests.get(url, headers={'User-Agent': ua}, timeout=15)
if req.ok:
try:
content = req.content.decode('utf-8')
data = json.loads(content)
return data
except ValueError:
# crt.sh fixed their JSON response. This shouldn't be necessary anymore
# https://github.com/crtsh/certwatch_db/commit/f4f46ea37c23543c4cdf1a3c8867d68967641807
data = json.loads("[{}]".format(content.replace('}{', '},{')))
return data
except Exception as err:
logger.error("Error retrieving cert information from CRT.sh.")
return None
def gather_crtsh_certs_for_target(target):
"""
Given a URL or domain-like input, return crt.sh results for:
- The exact hostname
- If hostname is a subdomain, also the wildcard for the root domain (e.g., *.example.com)
We intentionally run this even if the scheme is HTTP (per your design).
Expired certs are excluded by default.
Returns:
dict:
{
"input": <original target>,
"hostname": <parsed hostname>,
"root_domain": <registrable>,
"is_root_domain": <bool>,
"crtsh": {
"host_certs": [... or None],
"wildcard_root_certs": [... or None]
}
}
"""
result = {
"input": target,
"hostname": None,
"root_domain": None,
"is_root_domain": False,
"crtsh": {
"host_certs": None,
"wildcard_root_certs": None
}
}
try:
hostname = parse_target_to_host(target)
result["hostname"] = hostname
if hostname is None:
return result
root = get_root_domain(hostname)
result["root_domain"] = root
result["is_root_domain"] = is_root_domain(hostname)
# Always query crt.sh for the specific hostname
# (expired=False means we filter expired)
host_certs = search_certs(hostname, wildcard=False, expired=False)
result["crtsh"]["host_certs"] = host_certs
# If subdomain, also look up wildcard for the root domain: *.root
if not result["is_root_domain"] and root:
wildcard_certs = search_certs(root, wildcard=True, expired=False)
result["crtsh"]["wildcard_root_certs"] = wildcard_certs
except Exception as exc:
logger.exception("crt.sh enrichment failed: %s", exc)
return result
def enrich_url(url: str, fetch_ssl_enabled: bool = False) -> dict:
"""Perform WHOIS, GeoIP, and optional SSL/TLS enrichment."""
enrichment = {}
# Extract hostname
parsed = urlparse(url)
hostname = parsed.hostname or url # fallback if parsing fails
# --- WHOIS ---
result.update(enrich_whois(hostname))
enrichment.update(enrich_whois(hostname))
# --- GeoIP ---
result["geoip"] = enrich_geoip(hostname)
enrichment["geoip"] = enrich_geoip(hostname)
return result
# === SSL/TLS: crt.sh + live probe ===
# if fetching ssl...
if fetch_ssl_enabled:
try:
# 1) Certificate Transparency (already implemented previously)
crtsh_info = gather_crtsh_certs_for_target(url)
# 2) Live TLS probe (versions + negotiated cipher per version)
tls_enum = TLSEnumerator(timeout_seconds=5.0)
probe_result = tls_enum.probe(url)
enrichment["ssl_tls"] = {}
enrichment["ssl_tls"]["crtsh"] = crtsh_info
enrichment["ssl_tls"]["probe"] = probe_result.to_dict()
except Exception as exc:
logger.exception("SSL/TLS enrichment failed: %s", exc)
enrichment["ssl_tls"] = {"error": "SSL/TLS enrichment failed"}
else:
# Include a small marker so the UI can show “skipped”
enrichment["ssl_tls"] = {"skipped": True, "reason": "Disabled on submission"}
return enrichment
def enrich_whois(hostname: str) -> dict:
"""Fetch WHOIS info using python-whois with safe type handling."""
cache_key = f"whois:{hostname}"
cached = cache.read(cache_key)
if cached:
logging.info(f"[CACHE HIT] for WHOIS: {hostname}")
logger.info(f"[CACHE HIT] for WHOIS: {hostname}")
return cached
logging.info(f"[CACHE MISS] for WHOIS: {hostname}")
logger.info(f"[CACHE MISS] for WHOIS: {hostname}")
result = {}
try:
w = whois.whois(hostname)
@@ -73,7 +291,7 @@ def enrich_whois(hostname: str) -> dict:
}
except Exception as e:
logging.warning(f"WHOIS lookup failed for {hostname}: {e}")
logger.warning(f"WHOIS lookup failed for {hostname}: {e}")
try:
# fallback raw whois text
import subprocess
@@ -81,14 +299,13 @@ def enrich_whois(hostname: str) -> dict:
result["whois"] = {}
result["raw_whois"] = raw_output
except Exception as raw_e:
logging.error(f"Raw WHOIS also failed: {raw_e}")
logger.error(f"Raw WHOIS also failed: {raw_e}")
result["whois"] = {}
result["raw_whois"] = "N/A"
cache.create(cache_key, result, WHOIS_DEFAULT_TTL)
return result
def enrich_geoip(hostname: str) -> dict:
"""Resolve hostname to IPs and fetch info from ip-api.com."""
geo_info = {}
@@ -98,11 +315,11 @@ def enrich_geoip(hostname: str) -> dict:
cache_key = f"geoip:{ip_str}"
cached = cache.read(cache_key)
if cached:
logging.info(f"[CACHE HIT] for GEOIP: {ip}")
logger.info(f"[CACHE HIT] for GEOIP: {ip}")
geo_info[ip_str] = cached
continue
logging.info(f"[CACHE MISS] for GEOIP: {ip}")
logger.info(f"[CACHE MISS] for GEOIP: {ip}")
try:
resp = requests.get(f"http://ip-api.com/json/{ip_str}?fields=24313855", timeout=5)
if resp.status_code == 200:
@@ -116,7 +333,6 @@ def enrich_geoip(hostname: str) -> dict:
return geo_info
def extract_ips_from_url(hostname: str):
"""Resolve hostname to IPs."""
try:


@@ -1,9 +1,10 @@
import json
import logging
from pathlib import Path
from datetime import datetime
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
from app.logging_setup import get_app_logger
logger = get_app_logger()
def safe_write(path: Path | str, content: str, mode="w", encoding="utf-8"):
"""Write content to a file safely with logging."""
@@ -12,9 +13,9 @@ def safe_write(path: Path | str, content: str, mode="w", encoding="utf-8"):
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, mode, encoding=encoding) as f:
f.write(content)
logging.info(f"[+] Wrote file: {path}")
logger.info(f"[+] Wrote file: {path}")
except Exception as e:
logging.error(f"[!] Failed writing {path}: {e}")
logger.error(f"[!] Failed writing {path}: {e}")
raise
def get_recent_results(storage_dir: Path, limit: int, logger) -> list[dict]:


@@ -1,291 +0,0 @@
"""
rules_engine.py
Flask-logger integrated rules engine for SneakyScope.
Logs go to `current_app.logger` when a Flask app context is active,
otherwise to a namespaced standard logger "sneakyscope.rules".
"""
import re
import logging
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import yaml
try:
# Flask is optional; engine still works without it.
from flask import current_app, has_app_context
except Exception:
current_app = None # type: ignore
def has_app_context() -> bool: # type: ignore
return False
def get_engine_logger() -> logging.Logger:
"""
Return a logger that prefers Flask's current_app.logger if available.
Falls back to a namespaced standard logger otherwise.
"""
if has_app_context() and current_app is not None and hasattr(current_app, "logger"):
return current_app.logger
return logging.getLogger("sneakyscope.rules")
@dataclass
class Rule:
"""
Represents a single detection rule.
When rule_type == 'regex', 'pattern' must be provided.
When rule_type == 'function', 'function' must be provided and return (matched: bool, reason: str).
"""
name: str
description: str
category: str
rule_type: str = "regex"
pattern: Optional[str] = None
function: Optional[Callable[[str], Tuple[bool, str]]] = None
severity: Optional[str] = None # 'low' | 'medium' | 'high' (optional)
tags: Optional[List[str]] = field(default=None) # e.g., ['obfuscation', 'phishing'] (optional)
# Internal compiled regex cache (not serialized)
_compiled_regex: Optional[re.Pattern] = field(default=None, repr=False, compare=False)
def compile_if_needed(self, logger: Optional[logging.Logger] = None) -> bool:
"""
Compile the regex pattern once for performance, if applicable.
Returns:
bool: True if the regex is compiled and ready, False otherwise.
"""
if logger is None:
logger = get_engine_logger()
if self.rule_type == "regex" and self.pattern:
try:
self._compiled_regex = re.compile(self.pattern, re.IGNORECASE)
logger.debug(f"[Rule] Compiled regex for '{self.name}'")
return True
except re.error as rex:
self._compiled_regex = None
logger.warning(f"[Rule] Failed to compile regex for '{self.name}': {rex}")
return False
return False
def run(self, text: str, logger: Optional[logging.Logger] = None) -> Tuple[bool, str]:
"""
Run the rule on the given text.
Returns:
(matched: bool, reason: str)
"""
if logger is None:
logger = get_engine_logger()
if self.rule_type == "regex":
if not self.pattern:
logger.warning(f"[Rule] '{self.name}' missing regex pattern.")
return False, "Invalid rule configuration: missing pattern"
if self._compiled_regex is None:
compiled_ok = self.compile_if_needed(logger=logger)
if not compiled_ok:
return False, f"Invalid regex pattern: {self.pattern!r}"
if self._compiled_regex and self._compiled_regex.search(text):
return True, f"Matched regex '{self.pattern}'{self.description}"
return False, "No match"
if self.rule_type == "function":
if callable(self.function):
try:
matched, reason = self.function(text)
if isinstance(matched, bool) and isinstance(reason, str):
return matched, reason
logger.warning(f"[Rule] '{self.name}' function returned invalid types.")
return False, "Invalid function return type; expected (bool, str)"
except Exception as exc:
logger.exception(f"[Rule] '{self.name}' function raised exception.")
return False, f"Rule function raised exception: {exc!r}"
logger.warning(f"[Rule] '{self.name}' has invalid function configuration.")
return False, "Invalid rule configuration: function not callable"
logger.warning(f"[Rule] '{self.name}' has unknown type '{self.rule_type}'.")
return False, f"Invalid rule configuration: unknown type '{self.rule_type}'"
@dataclass
class RuleResult:
"""
Uniform per-rule outcome for UI/API consumption.
result is "PASS" or "FAIL" (FAIL == matched True)
"""
name: str
description: str
category: str
result: str # "PASS" | "FAIL"
reason: Optional[str] = None
severity: Optional[str] = None
tags: Optional[List[str]] = None
class RuleEngine:
"""
Loads and executes rules against provided text, with Flask-aware logging.
"""
def __init__(self, rules: Optional[List[Rule]] = None, logger: Optional[logging.Logger] = None):
"""
Args:
rules: Optional initial rule list.
logger: Optional explicit logger. If None, uses Flask app logger if available,
otherwise a namespaced standard logger.
"""
if logger is None:
self.logger = get_engine_logger()
else:
self.logger = logger
self.rules: List[Rule] = rules or []
self._compile_all()
def _compile_all(self) -> None:
"""
Compile all regex rules at initialization and warn about failures.
"""
index = 0
total = len(self.rules)
while index < total:
rule = self.rules[index]
if rule.rule_type == "regex":
compiled_ok = rule.compile_if_needed(logger=self.logger)
if not compiled_ok:
self.logger.warning(f"[Engine] Regex failed at init for rule '{rule.name}' (pattern={rule.pattern!r})")
index = index + 1
def add_rule(self, rule: Rule) -> None:
"""
Add a new rule at runtime; compiles regex if needed and logs failures.
"""
self.rules.append(rule)
if rule.rule_type == "regex":
compiled_ok = rule.compile_if_needed(logger=self.logger)
if not compiled_ok:
self.logger.warning(f"[Engine] Regex failed when adding rule '{rule.name}' (pattern={rule.pattern!r})")
def run_all(self, text: str, category: Optional[str] = None) -> List[Dict]:
"""
Run all rules against text.
Args:
text: The content to test.
category: If provided, only evaluate rules that match this category.
Returns:
List of dicts with PASS/FAIL per rule (JSON-serializable).
"""
results: List[Dict] = []
index = 0
total = len(self.rules)
while index < total:
rule = self.rules[index]
if category is not None and rule.category != category:
index = index + 1
continue
matched, reason = rule.run(text, logger=self.logger)
result_str = "FAIL" if matched else "PASS"
reason_to_include: Optional[str]
if matched:
reason_to_include = reason
else:
reason_to_include = None
rr = RuleResult(
name=rule.name,
description=rule.description,
category=rule.category,
result=result_str,
reason=reason_to_include,
severity=rule.severity,
tags=rule.tags,
)
results.append(asdict(rr))
index = index + 1
self.logger.debug(f"[Engine] Completed evaluation. Returned {len(results)} rule results.")
return results
def load_rules_from_yaml(yaml_file: Union[str, Path], logger: Optional[logging.Logger] = None) -> List[Rule]:
"""
Load rules from a YAML file.
Supports optional 'severity' and 'tags' keys.
Example YAML:
- name: suspicious_eval
description: "Use of eval() in script"
category: script
type: regex
pattern: "\\beval\\("
severity: medium
tags: [obfuscation]
Returns:
List[Rule]
"""
if logger is None:
logger = get_engine_logger()
rules: List[Rule] = []
path = Path(yaml_file)
with path.open("r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, list):
logger.error("[Loader] Rules YAML must be a list of rule objects.")
raise ValueError("Rules YAML must be a list of rule objects.")
idx = 0
total = len(data)
while idx < total:
item = data[idx]
name = item.get("name")
description = item.get("description")
category = item.get("category")
rule_type = item.get("type", "regex")
pattern = item.get("pattern")
severity = item.get("severity")
tags = item.get("tags")
if not name or not description or not category:
logger.warning(f"[Loader] Skipping invalid rule at index {idx}: missing required fields.")
idx = idx + 1
continue
rule = Rule(
name=name,
description=description,
category=category,
rule_type=rule_type,
pattern=pattern,
function=None, # function rules should be registered in code
severity=severity,
tags=tags if isinstance(tags, list) else None,
)
rules.append(rule)
idx = idx + 1
logger.info(f"[Loader] Loaded {len(rules)} rules from '{yaml_file}'.")
return rules


@@ -63,6 +63,7 @@ class AppConfig:
name: str = "MyApp"
version_major: int = 1
version_minor: int = 0
print_rule_loads: bool = False
@dataclass

app/utils/tls_probe.py (new file, 270 lines)

@@ -0,0 +1,270 @@
import socket
import ssl
import time
import logging
from urllib.parse import urlparse
class TLSProbeResult:
"""
Container for the results of a TLS probe across protocol versions.
"""
def __init__(self):
self.hostname = None
self.port = 443
self.results_by_version = {} # e.g., {"TLS1.2": {"supported": True, "cipher": "TLS_AES_128_GCM_SHA256", ...}}
self.weak_protocols = [] # e.g., ["TLS1.0", "TLS1.1"]
self.weak_ciphers = [] # e.g., ["RC4-SHA"]
self.errors = [] # textual errors encountered during probing
def to_dict(self):
"""
Convert the object to a serializable dictionary.
"""
output = {
"hostname": self.hostname,
"port": self.port,
"results_by_version": self.results_by_version,
"weak_protocols": self.weak_protocols,
"weak_ciphers": self.weak_ciphers,
"errors": self.errors
}
return output
class TLSEnumerator:
"""
Enumerate supported TLS versions for a server by attempting handshakes with constrained contexts.
Also collects the server-selected cipher for each successful handshake.
Notes:
- We do NOT validate certificates; this is posture discovery, not trust verification.
- Cipher enumeration is limited to "what was negotiated with default cipher list" per version.
Deep cipher scanning (per-cipher attempts) can be added later if needed.
"""
def __init__(self, timeout_seconds=5.0):
self.timeout_seconds = float(timeout_seconds)
def _build_context_for_version(self, tls_version_label):
"""
Build an SSLContext that only allows the specified TLS version.
"""
# Base client context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
# Disable certificate checks so we can probe misconfigured/self-signed endpoints
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE
# Constrain to a single protocol version using minimum/maximum
# Map label -> ssl.TLSVersion
if tls_version_label == "TLS1.0" and hasattr(ssl.TLSVersion, "TLSv1"):
context.minimum_version = ssl.TLSVersion.TLSv1
context.maximum_version = ssl.TLSVersion.TLSv1
elif tls_version_label == "TLS1.1" and hasattr(ssl.TLSVersion, "TLSv1_1"):
context.minimum_version = ssl.TLSVersion.TLSv1_1
context.maximum_version = ssl.TLSVersion.TLSv1_1
elif tls_version_label == "TLS1.2" and hasattr(ssl.TLSVersion, "TLSv1_2"):
context.minimum_version = ssl.TLSVersion.TLSv1_2
context.maximum_version = ssl.TLSVersion.TLSv1_2
elif tls_version_label == "TLS1.3" and hasattr(ssl.TLSVersion, "TLSv1_3"):
context.minimum_version = ssl.TLSVersion.TLSv1_3
context.maximum_version = ssl.TLSVersion.TLSv1_3
else:
# Version not supported by this Python/OpenSSL build
return None
# Keep default cipher list; we only want to see what is negotiated
# You can later set context.set_ciphers("...") for deeper scans.
return context
def _attempt_handshake(self, hostname, port, context):
"""
Attempt a TLS handshake to (hostname, port) using the given context.
Returns a tuple: (supported(bool), selected_cipher(str or None), elapsed_seconds(float or None), error(str or None))
"""
supported = False
selected_cipher = None
elapsed = None
error_text = None
# Create a TCP connection with a timeout
sock = None
ssock = None
start = None
try:
# Resolve and connect
# Note: create_connection will handle IPv4/IPv6 resolution
sock = socket.create_connection((hostname, port), timeout=self.timeout_seconds)
# Start timer right before TLS wrap to capture handshake duration mainly
start = time.time()
# SNI is important: pass server_hostname
ssock = context.wrap_socket(sock, server_hostname=hostname)
# Access negotiated cipher; returns (cipher_name, protocol, secret_bits)
cipher_info = ssock.cipher()
if cipher_info is not None and len(cipher_info) >= 1:
selected_cipher = str(cipher_info[0])
supported = True
elapsed = time.time() - start
except Exception as exc:
# Capture the error for diagnostics
error_text = f"{type(exc).__name__}: {str(exc)}"
elapsed = None
finally:
# Clean up sockets
try:
if ssock is not None:
ssock.close()
except Exception:
pass
try:
if sock is not None:
sock.close()
except Exception:
pass
return supported, selected_cipher, elapsed, error_text
def probe(self, target):
"""
Probe the target (URL or hostname or 'hostname:port') for TLS 1.0/1.1/1.2/1.3 support.
Returns TLSProbeResult.
"""
result = TLSProbeResult()
host, port = self._parse_target_to_host_port(target)
result.hostname = host
result.port = port
if host is None:
result.errors.append("Unable to parse a hostname from the target.")
return result
# Define the versions we will test, in ascending order
versions_to_test = ["TLS1.0", "TLS1.1", "TLS1.2", "TLS1.3"]
# Iterate explicitly to match your coding style preference
for version_label in versions_to_test:
context = self._build_context_for_version(version_label)
# If this Python/OpenSSL cannot restrict to this version, mark as unsupported_by_runtime
if context is None:
version_outcome = {
"supported": False,
"selected_cipher": None,
"handshake_seconds": None,
"error": "Version not supported by local runtime"
}
result.results_by_version[version_label] = version_outcome
continue
supported, cipher, elapsed, err = self._attempt_handshake(host, port, context)
version_outcome = {
"supported": supported,
"selected_cipher": cipher,
"handshake_seconds": elapsed,
"error": err
}
result.results_by_version[version_label] = version_outcome
# Determine weak protocols (if the handshake succeeded on legacy versions)
# RFC 8996 and industry guidance deprecate TLS 1.0 and 1.1.
try:
v10 = result.results_by_version.get("TLS1.0")
if v10 is not None and v10.get("supported") is True:
result.weak_protocols.append("TLS1.0")
except Exception:
pass
try:
v11 = result.results_by_version.get("TLS1.1")
if v11 is not None and v11.get("supported") is True:
result.weak_protocols.append("TLS1.1")
except Exception:
pass
# Flag weak ciphers encountered in any successful negotiation
# This is a heuristic: we only see the single chosen cipher per version.
try:
for label in ["TLS1.0", "TLS1.1", "TLS1.2", "TLS1.3"]:
outcome = result.results_by_version.get(label)
if outcome is None:
continue
if outcome.get("supported") is not True:
continue
cipher_name = outcome.get("selected_cipher")
if cipher_name is None:
continue
# Simple string-based checks for known-weak families
# (RC4, 3DES, NULL, EXPORT, MD5). Expand as needed.
name_upper = str(cipher_name).upper()
is_weak = False
if "RC4" in name_upper:
is_weak = True
elif "3DES" in name_upper or "DES-CBC3" in name_upper:
is_weak = True
elif "NULL" in name_upper:
is_weak = True
elif "EXPORT" in name_upper or "EXP-" in name_upper:
is_weak = True
elif "-MD5" in name_upper:
is_weak = True
if is_weak:
# Avoid duplicates
if cipher_name not in result.weak_ciphers:
result.weak_ciphers.append(cipher_name)
except Exception as exc:
result.errors.append(f"Cipher analysis error: {exc}")
return result
def _parse_target_to_host_port(self, target):
"""
Accepts URL, hostname, or 'hostname:port' and returns (hostname, port).
Defaults to port 443 if not specified.
"""
if target is None:
return None, 443
text = str(target).strip()
if text == "":
return None, 443
# If it's clearly a URL, parse it normally
if "://" in text:
parsed = urlparse(text)
hostname = parsed.hostname
port = parsed.port
if hostname is None:
return None, 443
if port is None:
port = 443
return hostname.lower(), int(port)
# If it's host:port, split safely
# Note: URLs without scheme can be tricky (IPv6), but we'll handle [::1]:443 form later if needed
if ":" in text and text.count(":") == 1:
host_part, port_part = text.split(":")
host_part = host_part.strip()
port_part = port_part.strip()
if host_part == "":
return None, 443
try:
port_value = int(port_part)
except Exception:
port_value = 443
return host_part.lower(), int(port_value)
# Otherwise treat it as a bare hostname
return text.lower(), 443
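
A quick usage sketch of the enumerator above (output values are illustrative):

```python
from app.utils.tls_probe import TLSEnumerator

enumerator = TLSEnumerator(timeout_seconds=5.0)
result = enumerator.probe("https://example.com")   # URL, hostname, or host:port
summary = result.to_dict()
print(summary["weak_protocols"])                   # e.g. ["TLS1.0", "TLS1.1"]
print(summary["results_by_version"]["TLS1.3"])     # supported flag + negotiated cipher
```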

app/utils/url_tools.py (new file, 133 lines)

@@ -0,0 +1,133 @@
# app/utils/url_tools.py
from urllib.parse import urlparse, urlunparse
import requests
import idna
# Reuse existing decorator (import from wherever you defined it)
from app.utils.settings import singleton_loader
class URLNormalizer:
"""
Normalize user input into a fully-qualified URL for analysis.
Behavior:
- If no scheme is present, prepend https:// by default.
- Optional quick HTTPS reachability check with fallback to http://.
- Converts Unicode hostnames to punycode via IDNA.
Notes:
- Keep the first-constructed configuration stable via the singleton factory.
- Avoids Flask/current_app/threading per your project style.
"""
def __init__(self, prefer_https: bool = True, fallback_http: bool = False, connect_timeout: float = 2.0):
self.prefer_https = bool(prefer_https)
self.fallback_http = bool(fallback_http)
self.connect_timeout = float(connect_timeout)
def normalize_for_analysis(self, raw_input: str) -> str:
"""
Convert raw input (URL or domain) into a normalized URL string.
Raises:
ValueError: if input is empty/invalid.
"""
if raw_input is None:
raise ValueError("Empty input")
text = str(raw_input).strip()
if text == "":
raise ValueError("Empty input")
# Repair common typos (missing colon)
lower = text.lower()
if lower.startswith("http//"):
text = "http://" + text[6:]
elif lower.startswith("https//"):
text = "https://" + text[7:]
# Respect an existing scheme
if "://" in text:
parsed = urlparse(text)
return self._recompose_with_punycode_host(parsed)
# No scheme -> build one
if self.prefer_https:
https_url = "https://" + text
if self.fallback_http:
if self._quick_https_ok(https_url):
return self._recompose_with_punycode_host(urlparse(https_url))
http_url = "http://" + text
return self._recompose_with_punycode_host(urlparse(http_url))
return self._recompose_with_punycode_host(urlparse(https_url))
http_url = "http://" + text
return self._recompose_with_punycode_host(urlparse(http_url))
def _recompose_with_punycode_host(self, parsed):
"""
Recompose a parsed URL with hostname encoded to ASCII (punycode).
Preserves userinfo, port, path, params, query, fragment.
"""
host = parsed.hostname
if host is None:
return urlunparse(parsed)
try:
ascii_host = idna.encode(host).decode("ascii")
except Exception:
ascii_host = host
# rebuild netloc (auth + port)
netloc = ascii_host
if parsed.port:
netloc = f"{netloc}:{parsed.port}"
if parsed.username:
if parsed.password:
netloc = f"{parsed.username}:{parsed.password}@{netloc}"
else:
netloc = f"{parsed.username}@{netloc}"
return urlunparse((
parsed.scheme,
netloc,
parsed.path or "",
parsed.params or "",
parsed.query or "",
parsed.fragment or "",
))
def _quick_https_ok(self, https_url: str) -> bool:
"""
Quick reachability check for https:// using a HEAD request.
Redirects allowed; TLS verify disabled — posture-only.
"""
try:
resp = requests.head(https_url, allow_redirects=True, timeout=self.connect_timeout, verify=False)
_ = resp.status_code
return True
except Exception:
return False
# ---- Singleton factory using our decorator ----
@singleton_loader
def get_url_normalizer(
prefer_https: bool = True,
fallback_http: bool = False,
connect_timeout: float = 2.0,
) -> URLNormalizer:
"""
Return the singleton URLNormalizer instance.
IMPORTANT: With this decorator, the FIRST call's arguments "win".
Later calls return the cached instance and ignore new arguments.
"""
return URLNormalizer(
prefer_https=prefer_https,
fallback_http=fallback_http,
connect_timeout=connect_timeout,
)