Compare commits
46 Commits
d845fa45a3
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| 6145a23296 | |||
| 16d79df421 | |||
| 1ee721ac10 | |||
| d54a3480b8 | |||
| d3b286ba40 | |||
| d829e6553c | |||
| 2c532adbbc | |||
| be1ea81102 | |||
| 3fe0f7af47 | |||
| 05754fe06b | |||
| 0886727437 | |||
| 638aecb561 | |||
| b878408f3e | |||
| 5b5c3098bb | |||
| 4e3da84578 | |||
| 2ad3df521d | |||
| 4496fce354 | |||
| 133bcbda57 | |||
| 7705008b9c | |||
| 9273d14845 | |||
| f0d8ef8f0a | |||
| 25fa7dc82b | |||
| 220c6613e4 | |||
| 22f10cd8e9 | |||
| 2ae8294e29 | |||
| 26bcbc6c1f | |||
| cc03f76593 | |||
| 90a38f12d1 | |||
| f93b28215d | |||
| 3f9012e6c2 | |||
| 7600195ecf | |||
| 13af7565dd | |||
| bdf7225472 | |||
| a15e428af0 | |||
| 99a15cbd9b | |||
| 202466f73d | |||
| 623ed14cbf | |||
| 641672e4c7 | |||
| cab3fbc1cf | |||
| 021fe340c1 | |||
| 8335978583 | |||
| 08da2c542a | |||
| ff40ef7803 | |||
| 3e88d1d481 | |||
| 76ba490aa2 | |||
| 82846d6236 |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -31,3 +31,9 @@ htmlcov/
|
|||||||
|
|
||||||
# uv
|
# uv
|
||||||
.python-version
|
.python-version
|
||||||
|
|
||||||
|
# Worktrees
|
||||||
|
.worktrees/
|
||||||
|
|
||||||
|
# SneakyCode local data
|
||||||
|
.sneakycode/
|
||||||
|
|||||||
40
.sneakycode/skills/brainstorm/prompt.md
Normal file
40
.sneakycode/skills/brainstorm/prompt.md
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Brainstorm Skill
|
||||||
|
|
||||||
|
You are in **brainstorming mode**. Your goal is creative ideation — generating multiple approaches, exploring trade-offs, and helping the user think through possibilities before committing to an implementation.
|
||||||
|
|
||||||
|
## Process
|
||||||
|
|
||||||
|
1. **Clarify the goal**: Make sure you understand what the user wants to achieve. Ask clarifying questions if needed.
|
||||||
|
2. **Divergent thinking**: Generate at least 3 distinct approaches. Push beyond the obvious — include creative or unconventional options.
|
||||||
|
3. **Evaluate trade-offs**: For each approach, identify:
|
||||||
|
- Pros and cons
|
||||||
|
- Complexity and effort estimate (low / medium / high)
|
||||||
|
- Risk factors
|
||||||
|
- What it enables or prevents in the future
|
||||||
|
4. **Synthesize**: Recommend your top pick with reasoning, but present all options fairly.
|
||||||
|
5. **Refine**: Ask the user which direction appeals to them and iterate.
|
||||||
|
|
||||||
|
## Guidelines
|
||||||
|
|
||||||
|
- Read relevant code first to ground your suggestions in reality (the explore skill has already run if chained).
|
||||||
|
- Don't just list options — explain *why* each one is interesting or viable.
|
||||||
|
- Be bold. Brainstorming is the place for ambitious ideas.
|
||||||
|
- If the user's initial framing seems limiting, gently challenge it.
|
||||||
|
- Avoid implementation details at this stage — focus on approach and design.
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
|
||||||
|
Present options as numbered approaches with clear headings:
|
||||||
|
|
||||||
|
### Approach 1: [Name]
|
||||||
|
[Description, pros, cons, complexity]
|
||||||
|
|
||||||
|
### Approach 2: [Name]
|
||||||
|
[Description, pros, cons, complexity]
|
||||||
|
|
||||||
|
### Approach 3: [Name]
|
||||||
|
[Description, pros, cons, complexity]
|
||||||
|
|
||||||
|
**Recommendation**: [Your pick and why]
|
||||||
|
|
||||||
|
When brainstorming is complete and the user has chosen a direction, call `finish_skill` summarizing the chosen approach.
|
||||||
9
.sneakycode/skills/brainstorm/skill.yaml
Normal file
9
.sneakycode/skills/brainstorm/skill.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
name: brainstorm
|
||||||
|
description: Creative ideation — divergent thinking, option generation, structured exploration
|
||||||
|
version: "1.0"
|
||||||
|
triggers: ["/brainstorm", "/bs"]
|
||||||
|
config_overrides:
|
||||||
|
temperature: 1.2
|
||||||
|
tools_disable: [write_file, make_dir, delete_file, str_replace, patch_apply, run_command]
|
||||||
|
chain: [explore]
|
||||||
|
prompts: [prompt.md]
|
||||||
31
.sneakycode/skills/explore/prompt.md
Normal file
31
.sneakycode/skills/explore/prompt.md
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# Explore Skill
|
||||||
|
|
||||||
|
You are in **exploration mode**. Your goal is to deeply understand the codebase or a specific area of it. Do NOT make any changes — only read, search, and analyze.
|
||||||
|
|
||||||
|
## Approach
|
||||||
|
|
||||||
|
1. **Start broad**: Use `list_dir` and `find_files` to understand the project structure
|
||||||
|
2. **Trace paths**: Follow imports, function calls, and data flow through the code
|
||||||
|
3. **Map relationships**: Identify which files depend on which, and how components interact
|
||||||
|
4. **Read carefully**: Use `read_file` to examine key files in detail
|
||||||
|
5. **Search patterns**: Use `grep_files` to find usage patterns, implementations, and references
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
|
||||||
|
Produce a structured summary with:
|
||||||
|
|
||||||
|
- **Architecture overview**: High-level description of the system's structure
|
||||||
|
- **Key components**: List of important files/classes and their responsibilities
|
||||||
|
- **Data flow**: How data moves through the system (requests, transformations, storage)
|
||||||
|
- **Dependencies**: Internal and external dependency map
|
||||||
|
- **Patterns**: Design patterns, conventions, and idioms used in the codebase
|
||||||
|
- **Observations**: Anything notable — potential issues, tech debt, clever solutions
|
||||||
|
|
||||||
|
## Guidelines
|
||||||
|
|
||||||
|
- Be thorough but focused. If the user specified an area, concentrate there.
|
||||||
|
- Don't guess — read the actual code before making claims.
|
||||||
|
- Quote specific file paths and line numbers when referencing code.
|
||||||
|
- If you find something unexpected or concerning, flag it clearly.
|
||||||
|
|
||||||
|
When you have completed your exploration, call `finish_skill` with a brief summary of your findings.
|
||||||
9
.sneakycode/skills/explore/skill.yaml
Normal file
9
.sneakycode/skills/explore/skill.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
name: explore
|
||||||
|
description: Deep codebase exploration — traces paths, maps architecture, summarizes findings
|
||||||
|
version: "1.0"
|
||||||
|
triggers: ["/explore", "/ex"]
|
||||||
|
config_overrides:
|
||||||
|
temperature: 0.3
|
||||||
|
tools_disable: [write_file, make_dir, delete_file, str_replace, patch_apply, run_command]
|
||||||
|
chain: []
|
||||||
|
prompts: [prompt.md]
|
||||||
50
.sneakycode/skills/plan/prompt.md
Normal file
50
.sneakycode/skills/plan/prompt.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Plan Skill
|
||||||
|
|
||||||
|
You are in **planning mode**. Your goal is to break down a task into a clear, actionable implementation plan. The explore skill has already run (if chained), so you have codebase context.
|
||||||
|
|
||||||
|
## Process
|
||||||
|
|
||||||
|
1. **Define scope**: Clearly state what the plan covers and what it does not.
|
||||||
|
2. **Decompose**: Break the task into discrete, ordered steps. Each step should be:
|
||||||
|
- Small enough to implement in one focused session
|
||||||
|
- Clear enough that someone unfamiliar could follow it
|
||||||
|
- Testable — you can verify the step was done correctly
|
||||||
|
3. **Identify dependencies**: Note which steps depend on others and the critical path.
|
||||||
|
4. **Map to files**: For each step, list the specific files to create or modify.
|
||||||
|
5. **Flag risks**: Identify anything that could go wrong, require decisions, or block progress.
|
||||||
|
|
||||||
|
## Output Format
|
||||||
|
|
||||||
|
```
|
||||||
|
# Implementation Plan: [Title]
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
[What this covers and what it doesn't]
|
||||||
|
|
||||||
|
## Steps
|
||||||
|
|
||||||
|
### Step 1: [Title]
|
||||||
|
- **Files**: [files to create/modify]
|
||||||
|
- **Description**: [what to do]
|
||||||
|
- **Depends on**: [prior steps, if any]
|
||||||
|
- **Verification**: [how to confirm it's done]
|
||||||
|
|
||||||
|
### Step 2: [Title]
|
||||||
|
...
|
||||||
|
|
||||||
|
## Risks & Open Questions
|
||||||
|
- [Risk or question]
|
||||||
|
|
||||||
|
## Build Order
|
||||||
|
[Recommended sequence, considering dependencies]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Guidelines
|
||||||
|
|
||||||
|
- Be specific — name exact files, functions, and modules.
|
||||||
|
- Keep steps granular. "Implement the backend" is too vague. "Add the /api/users endpoint with GET and POST handlers" is good.
|
||||||
|
- Consider both happy path and error cases in your plan.
|
||||||
|
- If you need to make assumptions, state them explicitly.
|
||||||
|
- Use `run_command` if you need to check project state (e.g., installed packages, running services).
|
||||||
|
|
||||||
|
When the plan is complete and the user has approved it, call `finish_skill` with a one-line summary.
|
||||||
9
.sneakycode/skills/plan/skill.yaml
Normal file
9
.sneakycode/skills/plan/skill.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
name: plan
|
||||||
|
description: Break down tasks, create roadmaps, plan implementations
|
||||||
|
version: "1.0"
|
||||||
|
triggers: ["/plan"]
|
||||||
|
config_overrides:
|
||||||
|
temperature: 0.5
|
||||||
|
tools_disable: [write_file, make_dir, delete_file, str_replace, patch_apply]
|
||||||
|
chain: [explore]
|
||||||
|
prompts: [prompt.md]
|
||||||
47
.sneakycode/skills/write-document/prompt.md
Normal file
47
.sneakycode/skills/write-document/prompt.md
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Write Document Skill
|
||||||
|
|
||||||
|
You are in **document writing mode**. Your goal is to draft, edit, or improve written documents — READMEs, technical specs, changelogs, guides, or any prose content.
|
||||||
|
|
||||||
|
## Workflow
|
||||||
|
|
||||||
|
### 1. Understand the Document
|
||||||
|
- What type of document? (README, spec, changelog, tutorial, etc.)
|
||||||
|
- Who is the audience? (developers, users, stakeholders)
|
||||||
|
- What is the desired tone? (formal, casual, technical)
|
||||||
|
- Are there existing documents to reference or update?
|
||||||
|
|
||||||
|
### 2. Outline
|
||||||
|
Before writing, propose a structure:
|
||||||
|
- List the main sections
|
||||||
|
- Note what each section should cover
|
||||||
|
- Get user approval on the outline before drafting
|
||||||
|
|
||||||
|
### 3. Draft
|
||||||
|
Write the full document based on the approved outline:
|
||||||
|
- Use clear, concise language
|
||||||
|
- Follow Markdown formatting conventions
|
||||||
|
- Include code examples where appropriate
|
||||||
|
- Be specific — avoid vague statements
|
||||||
|
|
||||||
|
### 4. Revise
|
||||||
|
After the initial draft:
|
||||||
|
- Check for consistency in tone and terminology
|
||||||
|
- Verify technical accuracy by reading referenced code
|
||||||
|
- Ensure all sections from the outline are covered
|
||||||
|
- Trim unnecessary content
|
||||||
|
|
||||||
|
## Document Templates
|
||||||
|
|
||||||
|
**README**: Project name, description, installation, usage, configuration, contributing, license
|
||||||
|
**Technical Spec**: Context, goals, non-goals, design, alternatives considered, implementation plan
|
||||||
|
**Changelog**: Version, date, categories (Added, Changed, Fixed, Removed)
|
||||||
|
**Guide/Tutorial**: Prerequisites, step-by-step instructions, examples, troubleshooting
|
||||||
|
|
||||||
|
## Guidelines
|
||||||
|
|
||||||
|
- Read existing project docs and code to ensure accuracy.
|
||||||
|
- Match the existing documentation style if updating.
|
||||||
|
- Prefer concrete examples over abstract descriptions.
|
||||||
|
- Use the `write_file` tool to save the document when the user approves.
|
||||||
|
|
||||||
|
When the document is complete and saved, call `finish_skill` with a summary of what was written.
|
||||||
8
.sneakycode/skills/write-document/skill.yaml
Normal file
8
.sneakycode/skills/write-document/skill.yaml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
name: write-document
|
||||||
|
description: Draft and edit documents — READMEs, specs, changelogs, prose
|
||||||
|
version: "1.0"
|
||||||
|
triggers: ["/write-doc", "/doc"]
|
||||||
|
config_overrides:
|
||||||
|
temperature: 0.7
|
||||||
|
chain: []
|
||||||
|
prompts: [prompt.md]
|
||||||
223
README.md
Normal file
223
README.md
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
# SneakyCode
|
||||||
|
|
||||||
|
A privacy-first, locally-running Python coding agent that uses a local LLM (via Ollama) to perform autonomous coding tasks inside a project directory.
|
||||||
|
|
||||||
|
SneakyCode accepts natural language tasks and executes them using a defined toolset for filesystem operations, shell execution, code search, and file manipulation. It runs a ReAct-style tool-call loop: send conversation history to the LLM, receive tool calls, execute them with permission checks, and feed results back until the task is complete.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Python 3.11+
|
||||||
|
- [Ollama](https://ollama.ai/) running locally with a model that supports function calling (e.g., `qwen3.5`, `llama3.1`, `mistral-nemo`)
|
||||||
|
- [uv](https://docs.astral.sh/uv/) (recommended) or pip
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone the repository
|
||||||
|
git clone <repo-url>
|
||||||
|
cd SneakyCode
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
uv sync --dev
|
||||||
|
|
||||||
|
# Or with pip
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Edit `config/config.yaml` to configure the agent. The full configuration reference:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
llm:
|
||||||
|
model: "qwen3.5:latest" # Ollama model name
|
||||||
|
endpoint: "http://localhost:11434" # Ollama endpoint
|
||||||
|
api_path: "/v1/chat/completions" # API endpoint path
|
||||||
|
temperature: 0.1 # Sampling temperature
|
||||||
|
max_tokens: 4096 # Maximum tokens in LLM response
|
||||||
|
timeout: 120 # Request timeout in seconds
|
||||||
|
max_retries: 3 # Retry attempts on transient errors
|
||||||
|
retry_backoff_base: 1.0 # Exponential backoff base (seconds)
|
||||||
|
retry_backoff_max: 30.0 # Maximum backoff seconds
|
||||||
|
|
||||||
|
agent:
|
||||||
|
max_iterations: 25 # Max tool-call iterations per turn
|
||||||
|
max_conversation_tokens: 32000 # Token budget for conversation
|
||||||
|
workspace_root: "." # Project directory for file operations
|
||||||
|
truncation_keep_recent: 10 # Messages preserved during truncation
|
||||||
|
truncation_threshold: 0.85 # Budget fraction that triggers truncation
|
||||||
|
|
||||||
|
session:
|
||||||
|
session_dir: ".sneakycode/sessions" # Directory for session files
|
||||||
|
auto_save: true # Save session after each turn
|
||||||
|
max_session_age_hours: 72 # Auto-cleanup old sessions
|
||||||
|
offer_resume: true # Offer to resume on startup
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
auto_approve: [read_file, list_dir, grep_files, find_files, finish]
|
||||||
|
prompt_user: [write_file, delete_file, run_command, str_replace, patch_apply, make_dir]
|
||||||
|
deny: []
|
||||||
|
|
||||||
|
tools:
|
||||||
|
shell:
|
||||||
|
allowed_commands: # Commands the LLM may run
|
||||||
|
- git
|
||||||
|
- python
|
||||||
|
- pip
|
||||||
|
- pytest
|
||||||
|
- ruff
|
||||||
|
- ls
|
||||||
|
- cat
|
||||||
|
- head
|
||||||
|
- tail
|
||||||
|
- wc
|
||||||
|
- diff
|
||||||
|
- grep
|
||||||
|
- find
|
||||||
|
- echo
|
||||||
|
denied_commands: # Blocked commands
|
||||||
|
- rm -rf /
|
||||||
|
- sudo
|
||||||
|
- curl
|
||||||
|
- wget
|
||||||
|
max_output_bytes: 65536 # Max captured output size (bytes)
|
||||||
|
filesystem:
|
||||||
|
max_file_size_bytes: 1048576 # 1 MB — max file size for read/write
|
||||||
|
binary_detection: true # Detect and reject binary files
|
||||||
|
|
||||||
|
display:
|
||||||
|
show_tool_calls: true # Show tool call details in output
|
||||||
|
show_token_usage: true # Show token usage stats
|
||||||
|
stream_output: true # Stream LLM output to terminal
|
||||||
|
|
||||||
|
skills:
|
||||||
|
enabled: true # Enable the skills system
|
||||||
|
directories: # Directories to scan for skill files
|
||||||
|
- ".sneakycode/skills"
|
||||||
|
|
||||||
|
debug:
|
||||||
|
enabled: false # Enable debug logging
|
||||||
|
log_dir: ".sneakycode/logs" # Debug log directory
|
||||||
|
max_files: 10 # Max debug log files to retain
|
||||||
|
```
|
||||||
|
|
||||||
|
Environment variable `SNEAKYCODE_CONFIG` can override the config file path.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start the interactive TUI
|
||||||
|
sneakycode
|
||||||
|
|
||||||
|
# Open a specific project directory
|
||||||
|
sneakycode /path/to/project
|
||||||
|
|
||||||
|
# Or run directly
|
||||||
|
python -m app.main
|
||||||
|
|
||||||
|
# With options
|
||||||
|
sneakycode --config path/to/config.yaml --verbose --log-file sneakycode.log
|
||||||
|
```
|
||||||
|
|
||||||
|
### CLI Options
|
||||||
|
|
||||||
|
| Option | Description |
|
||||||
|
|--------------------------|--------------------------------------------------|
|
||||||
|
| `DIRECTORY` | Project directory to use as workspace root |
|
||||||
|
| `--config PATH` | Path to config YAML file (default: `config/config.yaml`) |
|
||||||
|
| `-v`, `--verbose` | Enable verbose (DEBUG) logging |
|
||||||
|
| `--log-file PATH` | Path to log file for persistent logging |
|
||||||
|
|
||||||
|
### REPL Commands
|
||||||
|
|
||||||
|
| Command | Description |
|
||||||
|
|-------------------|----------------------------------------------------|
|
||||||
|
| `/help` | Show available commands |
|
||||||
|
| `/quit` | Save session and exit (also `/exit`, `/bye`) |
|
||||||
|
| `/history` | Show conversation history |
|
||||||
|
| `/clear` | Clear conversation history |
|
||||||
|
| `/save` | Manually save session |
|
||||||
|
| `/session` | Show session info (messages, tokens, start time) |
|
||||||
|
| `/models` | List available Ollama models |
|
||||||
|
| `/models <name>` | Switch to a different model |
|
||||||
|
| `/skills` | List available skills |
|
||||||
|
|
||||||
|
### Session Persistence
|
||||||
|
|
||||||
|
Sessions are automatically saved after each agent turn and on exit. On startup, SneakyCode offers to resume the most recent session for the current workspace.
|
||||||
|
|
||||||
|
Session files are stored in `.sneakycode/sessions/` within the workspace root (configurable via `session.session_dir`).
|
||||||
|
|
||||||
|
## Available Tools
|
||||||
|
|
||||||
|
SneakyCode provides tools across 6 categories. See [docs/tools.md](docs/tools.md) for the full reference.
|
||||||
|
|
||||||
|
| Category | Tools | Permission |
|
||||||
|
|------------|-------------------------------------------------|---------------|
|
||||||
|
| Read | `read_file`, `list_dir` | Auto-approved |
|
||||||
|
| Search | `grep_files`, `find_files` | Auto-approved |
|
||||||
|
| Write | `write_file`, `make_dir`, `delete_file` | User confirm |
|
||||||
|
| Edit | `str_replace`, `patch_apply` | User confirm |
|
||||||
|
| Shell | `run_command` | User confirm |
|
||||||
|
| Control | `finish` | Auto-approved |
|
||||||
|
| Skills | `load_skill` | Auto-approved |
|
||||||
|
|
||||||
|
The `load_skill` tool is available when `skills.enabled` is `true` in the config. It allows the LLM to load skill instructions from the configured skill directories.
|
||||||
|
|
||||||
|
## Skills
|
||||||
|
|
||||||
|
SneakyCode includes a skills system that lets you provide reusable instruction sets to the LLM. Skills are markdown files placed in `.sneakycode/skills/` (or any directory listed in `skills.directories`).
|
||||||
|
|
||||||
|
Skills are auto-discovered on startup. The LLM can load them via the `load_skill` tool, and you can list available skills with the `/skills` command.
|
||||||
|
|
||||||
|
To create a skill, add a `.md` file to your skills directory with a descriptive filename (e.g., `refactoring.md`). The file content is injected into the conversation when the skill is loaded.
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run tests
|
||||||
|
.venv/bin/python -m pytest tests/ -v
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
.venv/bin/python -m pytest tests/ --cov=app
|
||||||
|
|
||||||
|
# Lint
|
||||||
|
.venv/bin/ruff check app/ tests/
|
||||||
|
|
||||||
|
# Format
|
||||||
|
.venv/bin/ruff format app/ tests/
|
||||||
|
```
|
||||||
|
|
||||||
|
### Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
app/
|
||||||
|
├── agent/ # Agent loop and session context
|
||||||
|
├── models/ # Pydantic config and message schemas
|
||||||
|
├── services/ # LLM client, streaming, permissions, session persistence
|
||||||
|
├── tools/ # Tool implementations (one file per group)
|
||||||
|
├── ui/ # Textual TUI application and widgets
|
||||||
|
└── utils/ # Logging, display, file helpers, token counter
|
||||||
|
config/
|
||||||
|
└── config.yaml # Application configuration
|
||||||
|
tests/
|
||||||
|
├── unit/ # Unit tests for individual components
|
||||||
|
└── integration/ # End-to-end workflow tests with mocked LLM
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
SneakyCode follows a **ReAct-style** agent pattern:
|
||||||
|
|
||||||
|
1. User provides a task in natural language
|
||||||
|
2. Agent sends conversation history + tool schemas to the LLM
|
||||||
|
3. LLM responds with either text (task complete) or tool calls
|
||||||
|
4. Agent executes tool calls with permission checks
|
||||||
|
5. Results are fed back to the LLM for the next iteration
|
||||||
|
6. Loop continues until the LLM produces a plain-text response or calls `finish`
|
||||||
|
|
||||||
|
The LLM client is abstracted behind an OpenAI-compatible interface, so any endpoint implementing the `/v1/chat/completions` SSE streaming protocol works as a backend.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
MIT
|
||||||
@@ -77,3 +77,101 @@ class SessionContext:
|
|||||||
def start_time(self) -> datetime:
|
def start_time(self) -> datetime:
|
||||||
"""Session start timestamp (UTC)."""
|
"""Session start timestamp (UTC)."""
|
||||||
return self._start_time
|
return self._start_time
|
||||||
|
|
||||||
|
def truncate_history(self, system_token_estimate: int = 0) -> int:
|
||||||
|
"""Drop oldest messages to bring token usage under budget.
|
||||||
|
|
||||||
|
Preserves the first user message and the most recent N messages
|
||||||
|
(configured by ``truncation_keep_recent``). Cleans up orphaned tool
|
||||||
|
messages after truncation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_token_estimate: Estimated tokens used by the system prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of messages dropped.
|
||||||
|
"""
|
||||||
|
budget = self._token_counter.budget
|
||||||
|
threshold = self._config.agent.truncation_threshold
|
||||||
|
keep_recent = self._config.agent.truncation_keep_recent
|
||||||
|
|
||||||
|
estimated = self._token_counter.estimate_messages_tokens(self._history) + system_token_estimate
|
||||||
|
if estimated < threshold * budget:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
target = int(budget * 0.75) # headroom
|
||||||
|
if len(self._history) <= keep_recent + 1:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# Split: first user message | droppable middle | recent tail
|
||||||
|
first_msg = self._history[0] if self._history and self._history[0].role == "user" else None
|
||||||
|
start_idx = 1 if first_msg else 0
|
||||||
|
tail_start = max(start_idx, len(self._history) - keep_recent)
|
||||||
|
|
||||||
|
dropped = 0
|
||||||
|
drop_indices: set[int] = set()
|
||||||
|
|
||||||
|
for i in range(start_idx, tail_start):
|
||||||
|
drop_indices.add(i)
|
||||||
|
dropped += 1
|
||||||
|
# Recalculate with remaining messages
|
||||||
|
remaining = [m for j, m in enumerate(self._history) if j not in drop_indices]
|
||||||
|
est = self._token_counter.estimate_messages_tokens(remaining) + system_token_estimate
|
||||||
|
if est < target:
|
||||||
|
break
|
||||||
|
|
||||||
|
if dropped == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
self._history = [m for j, m in enumerate(self._history) if j not in drop_indices]
|
||||||
|
|
||||||
|
# Clean up orphaned tool messages
|
||||||
|
self._cleanup_orphaned_tool_messages()
|
||||||
|
|
||||||
|
return dropped
|
||||||
|
|
||||||
|
def _cleanup_orphaned_tool_messages(self) -> None:
|
||||||
|
"""Remove tool messages whose tool_call_id doesn't match any assistant tool_call."""
|
||||||
|
# Collect all tool_call IDs from assistant messages
|
||||||
|
valid_tc_ids: set[str] = set()
|
||||||
|
for msg in self._history:
|
||||||
|
if msg.role == "assistant" and msg.tool_calls:
|
||||||
|
for tc in msg.tool_calls:
|
||||||
|
valid_tc_ids.add(tc.id)
|
||||||
|
|
||||||
|
# Remove tool messages referencing missing tool calls
|
||||||
|
self._history = [
|
||||||
|
msg for msg in self._history
|
||||||
|
if msg.role != "tool" or (msg.tool_call_id and msg.tool_call_id in valid_tc_ids)
|
||||||
|
]
|
||||||
|
|
||||||
|
def to_serializable(self) -> dict:
|
||||||
|
"""Export messages and token state for session persistence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with messages and token usage data.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"messages": [m.model_dump(exclude_none=True) for m in self._history],
|
||||||
|
"token_usage": self._token_counter.cumulative_usage.model_dump(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def restore_from(self, data: dict) -> None:
|
||||||
|
"""Clear and replay from serialized data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dict with messages and optional token_usage as produced by to_serializable().
|
||||||
|
"""
|
||||||
|
self._history.clear()
|
||||||
|
self._message_count = 0
|
||||||
|
|
||||||
|
for msg_data in data.get("messages", []):
|
||||||
|
msg = Message(**msg_data)
|
||||||
|
self._history.append(msg)
|
||||||
|
self._message_count += 1
|
||||||
|
|
||||||
|
token_data = data.get("token_usage")
|
||||||
|
if token_data:
|
||||||
|
from app.utils.token_counter import TokenUsage
|
||||||
|
usage = TokenUsage(**token_data)
|
||||||
|
self._token_counter.count_usage(usage)
|
||||||
|
|||||||
@@ -1,26 +1,27 @@
|
|||||||
"""AgentLoop — ReAct-style tool-call loop for autonomous task execution."""
|
"""AgentLoop — ReAct-style tool-call loop for autonomous task execution."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from app.agent.context import SessionContext
|
from app.agent.context import SessionContext
|
||||||
from app.models.config import AppConfig
|
from app.models.config import AgentMode, AppConfig
|
||||||
from app.models.message import Message
|
from app.models.message import Message
|
||||||
from app.models.tool_call import ToolCall, ToolResult, ToolResultStatus
|
from app.models.tool_call import ToolCall, ToolResult, ToolResultStatus
|
||||||
from app.services.llm import LLMClient, LLMConnectionError, LLMError
|
from app.services.llm import LLMClient, LLMConnectionError, LLMError, LLMStreamError
|
||||||
from app.services.permissions import PermissionsService
|
from app.services.permissions import PermissionsService
|
||||||
from app.services.streaming import StreamHandler
|
from app.services.streaming import StreamHandler
|
||||||
from app.tools.registry import ToolRegistry
|
from app.tools.registry import ToolRegistry
|
||||||
from app.utils.display import (
|
from app.utils.display import DisplayAdapter
|
||||||
print_error,
|
|
||||||
print_iteration_header,
|
|
||||||
print_tool_call,
|
|
||||||
print_tool_result,
|
|
||||||
print_token_usage,
|
|
||||||
print_warning,
|
|
||||||
)
|
|
||||||
from app.utils.logging import get_logger
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.debug_log import DebugLogger
|
||||||
|
from app.services.skill_runner import SkillRunner
|
||||||
|
from app.services.skills import SkillsManager
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
_MAX_REASONING_RETRIES = 2
|
_MAX_REASONING_RETRIES = 2
|
||||||
@@ -42,6 +43,10 @@ class AgentLoop:
|
|||||||
handler: StreamHandler,
|
handler: StreamHandler,
|
||||||
registry: ToolRegistry,
|
registry: ToolRegistry,
|
||||||
permissions: PermissionsService,
|
permissions: PermissionsService,
|
||||||
|
display: DisplayAdapter | None = None,
|
||||||
|
debug_logger: DebugLogger | None = None,
|
||||||
|
skills_manager: SkillsManager | None = None,
|
||||||
|
skill_runner: SkillRunner | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._config = config
|
self._config = config
|
||||||
self._ctx = ctx
|
self._ctx = ctx
|
||||||
@@ -49,13 +54,28 @@ class AgentLoop:
|
|||||||
self._handler = handler
|
self._handler = handler
|
||||||
self._registry = registry
|
self._registry = registry
|
||||||
self._permissions = permissions
|
self._permissions = permissions
|
||||||
|
self._display = display
|
||||||
|
self._debug = debug_logger
|
||||||
|
self._skills = skills_manager
|
||||||
|
self._skill_runner = skill_runner
|
||||||
self._tools_schema = registry.get_openai_tools_schema()
|
self._tools_schema = registry.get_openai_tools_schema()
|
||||||
|
if self._permissions.mode == AgentMode.PLAN:
|
||||||
|
read_only = PermissionsService.READ_ONLY_TOOLS
|
||||||
|
self._tools_schema = [
|
||||||
|
t for t in self._tools_schema
|
||||||
|
if t["function"]["name"] in read_only
|
||||||
|
]
|
||||||
self._system_prompt = self._build_system_prompt()
|
self._system_prompt = self._build_system_prompt()
|
||||||
|
self._cancelled = False
|
||||||
|
|
||||||
|
def cancel(self) -> None:
|
||||||
|
"""Request cancellation of the current agent turn."""
|
||||||
|
self._cancelled = True
|
||||||
|
|
||||||
def _build_system_prompt(self) -> str:
|
def _build_system_prompt(self) -> str:
|
||||||
"""Build the system prompt including tool schemas and agent instructions."""
|
"""Build the system prompt including tool schemas and agent instructions."""
|
||||||
tool_names = [t["function"]["name"] for t in self._tools_schema]
|
tool_names = [t["function"]["name"] for t in self._tools_schema]
|
||||||
return (
|
prompt = (
|
||||||
"You are SneakyCode, a local AI coding agent. "
|
"You are SneakyCode, a local AI coding agent. "
|
||||||
"You help users with software engineering tasks by reading files, "
|
"You help users with software engineering tasks by reading files, "
|
||||||
"searching code, and answering questions about their project.\n\n"
|
"searching code, and answering questions about their project.\n\n"
|
||||||
@@ -68,11 +88,54 @@ class AgentLoop:
|
|||||||
"with a brief summary. If you can answer directly without tools, just respond "
|
"with a brief summary. If you can answer directly without tools, just respond "
|
||||||
"with text (no tool call needed)."
|
"with text (no tool call needed)."
|
||||||
)
|
)
|
||||||
|
if self._skills:
|
||||||
|
prompt += self._skills.get_system_prompt_snippet()
|
||||||
|
if self._skill_runner and self._skill_runner.is_active:
|
||||||
|
prompt += (
|
||||||
|
f"\n\nCurrently active skill: {self._skill_runner.active_skill_name}. "
|
||||||
|
"When the skill's objective is complete, call the `finish_skill` tool."
|
||||||
|
)
|
||||||
|
if self._permissions.mode == AgentMode.PLAN:
|
||||||
|
prompt += (
|
||||||
|
"\n\nYou are in PLAN mode. You may only use read-only tools: "
|
||||||
|
"read_file, list_dir, grep_files, find_files, finish. "
|
||||||
|
"Do NOT attempt to write files, edit code, or run commands. "
|
||||||
|
"Instead, describe what changes you would make, which files "
|
||||||
|
"you would modify, and provide the reasoning for each change."
|
||||||
|
)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
# Models whose chat templates understand /no_think directives.
|
||||||
|
_THINKING_MODEL_PREFIXES = ("qwen", "qwq")
|
||||||
|
|
||||||
|
def _model_supports_no_think(self) -> bool:
|
||||||
|
"""Check if the current model uses a thinking chat template."""
|
||||||
|
model_lower = self._config.llm.model.lower()
|
||||||
|
return any(model_lower.startswith(p) for p in self._THINKING_MODEL_PREFIXES)
|
||||||
|
|
||||||
def _get_messages_with_system_prompt(self) -> list[Message]:
|
def _get_messages_with_system_prompt(self) -> list[Message]:
|
||||||
"""Prepend the system prompt to conversation history."""
|
"""Prepend the system prompt to conversation history.
|
||||||
|
|
||||||
|
When thinking is disabled on a model that supports it, appends a
|
||||||
|
system-level /no_think directive after the last user message so
|
||||||
|
Qwen 3.x (and similar) chat templates see it.
|
||||||
|
"""
|
||||||
system_msg = Message(role="system", content=self._system_prompt)
|
system_msg = Message(role="system", content=self._system_prompt)
|
||||||
return [system_msg] + self._ctx.get_history()
|
history = self._ctx.get_history()
|
||||||
|
|
||||||
|
if not self._config.llm.thinking and self._model_supports_no_think() and history:
|
||||||
|
history = list(history)
|
||||||
|
# Find last user message and insert a system hint after it
|
||||||
|
for i in range(len(history) - 1, -1, -1):
|
||||||
|
if history[i].role == "user":
|
||||||
|
no_think_msg = Message(
|
||||||
|
role="system",
|
||||||
|
content="/no_think",
|
||||||
|
)
|
||||||
|
history.insert(i + 1, no_think_msg)
|
||||||
|
break
|
||||||
|
|
||||||
|
return [system_msg] + history
|
||||||
|
|
||||||
async def run_turn(self, user_input: str) -> None:
|
async def run_turn(self, user_input: str) -> None:
|
||||||
"""Execute one full agent turn: add user message, loop until done.
|
"""Execute one full agent turn: add user message, loop until done.
|
||||||
@@ -81,17 +144,32 @@ class AgentLoop:
|
|||||||
user_input: The user's message text.
|
user_input: The user's message text.
|
||||||
"""
|
"""
|
||||||
self._ctx.add_message("user", user_input)
|
self._ctx.add_message("user", user_input)
|
||||||
|
self._cancelled = False
|
||||||
|
|
||||||
max_iter = self._config.agent.max_iterations
|
max_iter = self._config.agent.max_iterations
|
||||||
reasoning_only_streak = 0
|
reasoning_only_streak = 0
|
||||||
|
empty_streak = 0
|
||||||
for iteration in range(1, max_iter + 1):
|
for iteration in range(1, max_iter + 1):
|
||||||
# Check token budget
|
if self._cancelled:
|
||||||
if self._ctx.token_counter.is_over_budget():
|
if self._display:
|
||||||
print_warning("Token budget exceeded. Stopping agent loop.")
|
self._display.write_warning("Agent loop cancelled.")
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Check token budget — try truncation before giving up
|
||||||
|
if self._ctx.token_counter.is_over_budget():
|
||||||
|
system_tokens = self._ctx.token_counter.estimate_tokens(self._system_prompt)
|
||||||
|
dropped = self._ctx.truncate_history(system_tokens)
|
||||||
|
if dropped > 0:
|
||||||
|
if self._display:
|
||||||
|
self._display.write_warning(f"Token budget pressure: dropped {dropped} oldest messages.")
|
||||||
|
else:
|
||||||
|
if self._display:
|
||||||
|
self._display.write_warning("Token budget exceeded, cannot truncate further. Stopping.")
|
||||||
|
break
|
||||||
|
|
||||||
if iteration > 1:
|
if iteration > 1:
|
||||||
print_iteration_header(iteration, max_iter)
|
if self._display:
|
||||||
|
self._display.write_iteration_header(iteration, max_iter)
|
||||||
|
|
||||||
# Stream LLM response
|
# Stream LLM response
|
||||||
assistant_msg = await self._llm_step()
|
assistant_msg = await self._llm_step()
|
||||||
@@ -112,11 +190,11 @@ class AgentLoop:
|
|||||||
if self._handler.usage:
|
if self._handler.usage:
|
||||||
self._ctx.token_counter.count_usage(self._handler.usage)
|
self._ctx.token_counter.count_usage(self._handler.usage)
|
||||||
|
|
||||||
if self._config.display.show_token_usage:
|
if self._config.display.show_token_usage and self._display:
|
||||||
total = self._ctx.token_counter.cumulative_usage.total_tokens
|
total = self._ctx.token_counter.cumulative_usage.total_tokens
|
||||||
if total == 0:
|
if total == 0:
|
||||||
total = self._ctx.estimated_tokens
|
total = self._ctx.estimated_tokens
|
||||||
print_token_usage(total, self._ctx.token_counter.budget)
|
self._display.write_token_usage(total, self._ctx.token_counter.budget)
|
||||||
|
|
||||||
self._handler.reset()
|
self._handler.reset()
|
||||||
|
|
||||||
@@ -125,30 +203,82 @@ class AgentLoop:
|
|||||||
reasoning_only_streak += 1
|
reasoning_only_streak += 1
|
||||||
self._ctx.pop_last_message()
|
self._ctx.pop_last_message()
|
||||||
|
|
||||||
if reasoning_only_streak >= _MAX_REASONING_RETRIES:
|
# When thinking is disabled, reasoning-only is expected model noise.
|
||||||
# Nudge the model by injecting a user hint
|
# Nudge immediately and silently to avoid wasting iterations.
|
||||||
print_warning(
|
thinking_disabled = not self._config.llm.thinking
|
||||||
f"Model produced reasoning but no response {reasoning_only_streak} times. "
|
|
||||||
"Nudging model to respond..."
|
# If the last context messages are tool errors, nudge immediately
|
||||||
)
|
# rather than wasting retries — the model is likely confused by the error.
|
||||||
|
has_recent_tool_error = any(
|
||||||
|
m.role == "tool" and m.content and m.content.startswith("Unknown ")
|
||||||
|
for m in self._ctx.get_history()[-3:]
|
||||||
|
)
|
||||||
|
|
||||||
|
should_nudge = (
|
||||||
|
thinking_disabled
|
||||||
|
or has_recent_tool_error
|
||||||
|
or reasoning_only_streak >= _MAX_REASONING_RETRIES
|
||||||
|
)
|
||||||
|
|
||||||
|
if should_nudge:
|
||||||
|
if not thinking_disabled and self._display:
|
||||||
|
self._display.write_warning(
|
||||||
|
f"Model produced reasoning but no response {reasoning_only_streak} times. "
|
||||||
|
"Nudging model to respond..."
|
||||||
|
)
|
||||||
self._ctx.add_message(
|
self._ctx.add_message(
|
||||||
"user",
|
"user",
|
||||||
"Please respond with your answer. Do not just think — provide your actual response.",
|
"Please respond with your answer. If a tool call failed, briefly explain what happened and continue.",
|
||||||
)
|
)
|
||||||
reasoning_only_streak = 0
|
reasoning_only_streak = 0
|
||||||
else:
|
else:
|
||||||
print_warning("Model produced reasoning but no response. Retrying...")
|
if self._display:
|
||||||
|
self._display.write_warning("Model produced reasoning but no response. Retrying...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Successful response — reset streak
|
# Successful response — reset streak
|
||||||
reasoning_only_streak = 0
|
reasoning_only_streak = 0
|
||||||
|
|
||||||
|
# Detect completely empty response (no content, no tool calls)
|
||||||
|
if not assistant_msg.content and not assistant_msg.tool_calls:
|
||||||
|
empty_streak += 1
|
||||||
|
self._ctx.pop_last_message() # Don't keep empty messages
|
||||||
|
if empty_streak >= 2:
|
||||||
|
if self._display:
|
||||||
|
self._display.write_warning(
|
||||||
|
"Model returned repeated empty responses — "
|
||||||
|
"try a different model or check Ollama logs."
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if self._display:
|
||||||
|
self._display.write_warning("Model returned empty response. Retrying without tools...")
|
||||||
|
# Retry without tool schemas — some models return empty when
|
||||||
|
# tools are in the payload but the model can't handle them.
|
||||||
|
assistant_msg = await self._llm_step(skip_tools=True)
|
||||||
|
if assistant_msg is None:
|
||||||
|
break
|
||||||
|
if assistant_msg.content:
|
||||||
|
self._ctx.add_message("assistant", assistant_msg.content)
|
||||||
|
if self._display:
|
||||||
|
self._display.write_assistant_message(assistant_msg.content)
|
||||||
|
self._handler.reset()
|
||||||
|
break
|
||||||
|
# Still empty even without tools
|
||||||
|
self._handler.reset()
|
||||||
|
continue
|
||||||
|
|
||||||
|
empty_streak = 0 # reset on successful non-empty response
|
||||||
|
|
||||||
|
# Display any assistant text content (even if tool calls follow)
|
||||||
|
if self._display and assistant_msg.content:
|
||||||
|
self._display.write_assistant_message(assistant_msg.content)
|
||||||
|
|
||||||
# No tool calls → task complete (plain text response)
|
# No tool calls → task complete (plain text response)
|
||||||
if not assistant_msg.tool_calls:
|
if not assistant_msg.tool_calls:
|
||||||
break
|
break
|
||||||
|
|
||||||
# Execute tool calls
|
# Execute tool calls
|
||||||
results = self._execute_tool_calls(assistant_msg.tool_calls)
|
results = await self._execute_tool_calls(assistant_msg.tool_calls)
|
||||||
|
|
||||||
# Add tool results to context
|
# Add tool results to context
|
||||||
for result in results:
|
for result in results:
|
||||||
@@ -160,34 +290,62 @@ class AgentLoop:
|
|||||||
name=result.tool_name,
|
name=result.tool_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if finish tool was called
|
# Rebuild tools schema and system prompt if skill state may have changed
|
||||||
|
if any(r.tool_name in ("load_skill", "finish_skill") for r in results):
|
||||||
|
self._tools_schema = self._registry.get_openai_tools_schema()
|
||||||
|
self._system_prompt = self._build_system_prompt()
|
||||||
|
|
||||||
|
# Check if finish tool was called (finish_skill does NOT break the loop)
|
||||||
if any(r.tool_name == "finish" for r in results):
|
if any(r.tool_name == "finish" for r in results):
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
print_warning(f"Agent reached maximum iterations ({max_iter}). Stopping.")
|
if self._display:
|
||||||
|
self._display.write_warning(f"Agent reached maximum iterations ({max_iter}). Stopping.")
|
||||||
|
|
||||||
async def _llm_step(self) -> Message | None:
|
async def _llm_step(self, *, skip_tools: bool = False) -> Message | None:
|
||||||
"""Stream one LLM response and return the accumulated Message.
|
"""Stream one LLM response and return the accumulated Message.
|
||||||
|
|
||||||
|
Uses retry-enabled streaming. On mid-stream errors, attempts to recover
|
||||||
|
partial content if available.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skip_tools: If True, send the request without tool schemas (fallback mode).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The assistant Message, or None if an error occurred.
|
The assistant Message, or None if an error occurred.
|
||||||
"""
|
"""
|
||||||
messages = self._get_messages_with_system_prompt()
|
messages = self._get_messages_with_system_prompt()
|
||||||
|
if self._debug:
|
||||||
|
self._debug.log_request(messages, self._config.llm.model)
|
||||||
|
tools = None if skip_tools else self._tools_schema
|
||||||
|
t0 = time.monotonic()
|
||||||
try:
|
try:
|
||||||
chunk_iter = self._client.stream_chat(messages, tools=self._tools_schema)
|
chunk_iter = self._client.stream_chat_with_retry(messages, tools=tools)
|
||||||
return await self._handler.process_stream(chunk_iter)
|
result = await self._handler.process_stream(chunk_iter)
|
||||||
|
if result and self._debug:
|
||||||
|
elapsed = (time.monotonic() - t0) * 1000
|
||||||
|
self._debug.log_response(result, self._handler.usage, elapsed)
|
||||||
|
return result
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print_warning("Response interrupted.")
|
if self._display:
|
||||||
|
self._display.write_warning("Response interrupted.")
|
||||||
self._handler.reset()
|
self._handler.reset()
|
||||||
return None
|
return None
|
||||||
except LLMConnectionError as e:
|
except (LLMConnectionError, LLMStreamError) as e:
|
||||||
print_error(f"Connection error: {e}")
|
partial = self._handler.get_partial_message()
|
||||||
|
if partial is not None:
|
||||||
|
if self._display:
|
||||||
|
self._display.write_warning(f"Stream interrupted ({e}), returning partial response.")
|
||||||
|
return partial
|
||||||
|
if self._display:
|
||||||
|
self._display.write_error(f"Connection error: {e}")
|
||||||
return None
|
return None
|
||||||
except LLMError as e:
|
except LLMError as e:
|
||||||
print_error(f"LLM error: {e}")
|
if self._display:
|
||||||
|
self._display.write_error(f"LLM error: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _execute_tool_calls(self, tool_calls: list[ToolCall]) -> list[ToolResult]:
|
async def _execute_tool_calls(self, tool_calls: list[ToolCall]) -> list[ToolResult]:
|
||||||
"""Execute a list of tool calls with permission checks.
|
"""Execute a list of tool calls with permission checks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -204,8 +362,8 @@ class AgentLoop:
|
|||||||
tc_id = tc.id
|
tc_id = tc.id
|
||||||
|
|
||||||
# Display the tool call
|
# Display the tool call
|
||||||
if self._config.display.show_tool_calls:
|
if self._config.display.show_tool_calls and self._display:
|
||||||
print_tool_call(name, tc.function.arguments)
|
self._display.write_tool_call(name, tc.function.arguments)
|
||||||
|
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
try:
|
try:
|
||||||
@@ -218,8 +376,8 @@ class AgentLoop:
|
|||||||
error=f"Invalid JSON in arguments: {e}",
|
error=f"Invalid JSON in arguments: {e}",
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
if self._config.display.show_tool_calls:
|
if self._config.display.show_tool_calls and self._display:
|
||||||
print_tool_result(name, result.error or "", is_error=True)
|
self._display.write_tool_result(name, result.error or "", is_error=True)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Look up tool
|
# Look up tool
|
||||||
@@ -232,13 +390,13 @@ class AgentLoop:
|
|||||||
error=f"Unknown tool '{name}'. Available: {available_names}",
|
error=f"Unknown tool '{name}'. Available: {available_names}",
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
if self._config.display.show_tool_calls:
|
if self._config.display.show_tool_calls and self._display:
|
||||||
print_tool_result(name, result.error or "", is_error=True)
|
self._display.write_tool_result(name, result.error or "", is_error=True)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check permissions (truncate args for display in prompt)
|
# Check permissions (truncate args for display in prompt)
|
||||||
desc = tc.function.arguments[:120] + "..." if len(tc.function.arguments) > 120 else tc.function.arguments
|
desc = tc.function.arguments[:120] + "..." if len(tc.function.arguments) > 120 else tc.function.arguments
|
||||||
if not self._permissions.check(name, description=desc):
|
if not await self._permissions.check(name, description=desc, arguments=tc.function.arguments):
|
||||||
result = ToolResult(
|
result = ToolResult(
|
||||||
tool_call_id=tc_id,
|
tool_call_id=tc_id,
|
||||||
tool_name=name,
|
tool_name=name,
|
||||||
@@ -246,17 +404,22 @@ class AgentLoop:
|
|||||||
error=f"Permission denied for tool '{name}'",
|
error=f"Permission denied for tool '{name}'",
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
if self._config.display.show_tool_calls:
|
if self._config.display.show_tool_calls and self._display:
|
||||||
print_tool_result(name, result.error or "", is_error=True)
|
self._display.write_tool_result(name, result.error or "", is_error=True)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Execute tool (BaseTool.run never raises)
|
# Execute tool (BaseTool.run never raises)
|
||||||
|
tool_t0 = time.monotonic()
|
||||||
result = tool.run(tc_id, parsed_args)
|
result = tool.run(tc_id, parsed_args)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
if self._config.display.show_tool_calls:
|
if self._debug:
|
||||||
|
tool_elapsed = (time.monotonic() - tool_t0) * 1000
|
||||||
|
self._debug.log_tool_execution(name, result.status.value, tool_elapsed)
|
||||||
|
|
||||||
|
if self._config.display.show_tool_calls and self._display:
|
||||||
is_error = result.status == ToolResultStatus.ERROR
|
is_error = result.status == ToolResultStatus.ERROR
|
||||||
output = result.error if is_error else result.output
|
output = result.error if is_error else result.output
|
||||||
print_tool_result(name, output or "", is_error=is_error)
|
self._display.write_tool_result(name, output or "", is_error=is_error)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|||||||
116
app/main.py
116
app/main.py
@@ -1,37 +1,19 @@
|
|||||||
"""SneakyCode entrypoint — argument parsing, config loading, and interactive REPL."""
|
"""SneakyCode entrypoint — argument parsing, config loading, and TUI launch."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import structlog
|
|
||||||
|
|
||||||
from app.agent.context import SessionContext
|
|
||||||
from app.agent.loop import AgentLoop
|
|
||||||
from app.models.config import AppConfig, load_config
|
from app.models.config import AppConfig, load_config
|
||||||
from app.services.llm import LLMClient, LLMConnectionError, LLMError
|
from app.services.llm import LLMClient, LLMConnectionError, LLMError
|
||||||
from app.services.permissions import PermissionsService
|
from app.services.session import SessionManager
|
||||||
from app.services.streaming import StreamHandler
|
from app.utils.display import print_banner, print_error, print_info, print_success
|
||||||
from app.tools.registry import create_default_registry
|
from app.utils.logging import get_logger, setup_logging
|
||||||
from app.utils.display import (
|
|
||||||
print_banner,
|
|
||||||
print_error,
|
|
||||||
print_history,
|
|
||||||
print_info,
|
|
||||||
print_success,
|
|
||||||
print_user_message,
|
|
||||||
print_warning,
|
|
||||||
)
|
|
||||||
from app.utils.logging import console, get_logger, setup_logging
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
"""Parse command-line arguments.
|
"""Parse command-line arguments."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parsed arguments namespace.
|
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog="sneakycode",
|
prog="sneakycode",
|
||||||
description="SneakyCode — A privacy-first local AI coding agent",
|
description="SneakyCode — A privacy-first local AI coding agent",
|
||||||
@@ -54,6 +36,13 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default=None,
|
default=None,
|
||||||
help="Path to log file for persistent logging",
|
help="Path to log file for persistent logging",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"directory",
|
||||||
|
nargs="?",
|
||||||
|
type=Path,
|
||||||
|
default=None,
|
||||||
|
help="Project directory to use as workspace root (default: current directory)",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -63,61 +52,11 @@ async def _preflight(config: AppConfig) -> None:
|
|||||||
await client.preflight_check()
|
await client.preflight_check()
|
||||||
|
|
||||||
|
|
||||||
async def _run_repl(
|
|
||||||
ctx: SessionContext,
|
|
||||||
config: AppConfig,
|
|
||||||
logger: structlog.stdlib.BoundLogger,
|
|
||||||
) -> None:
|
|
||||||
"""Run the interactive REPL loop with streaming LLM responses.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
ctx: Session context for conversation state.
|
|
||||||
config: Application configuration.
|
|
||||||
logger: Structured logger instance.
|
|
||||||
"""
|
|
||||||
registry = create_default_registry(config.agent.workspace_root, config)
|
|
||||||
permissions = PermissionsService(config.permissions)
|
|
||||||
|
|
||||||
async with LLMClient(config.llm) as client:
|
|
||||||
handler = StreamHandler(config.display)
|
|
||||||
agent = AgentLoop(config, ctx, client, handler, registry, permissions)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
user_input = console.input("[bold cyan]> [/bold cyan]")
|
|
||||||
except (KeyboardInterrupt, EOFError):
|
|
||||||
console.print("\n[dim]Goodbye![/dim]")
|
|
||||||
break
|
|
||||||
|
|
||||||
user_input = user_input.strip()
|
|
||||||
if not user_input:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle slash commands
|
|
||||||
if user_input.startswith("/"):
|
|
||||||
command = user_input.lower()
|
|
||||||
if command == "/quit":
|
|
||||||
console.print("[dim]Goodbye![/dim]")
|
|
||||||
break
|
|
||||||
elif command == "/history":
|
|
||||||
print_history(ctx.get_history())
|
|
||||||
elif command == "/clear":
|
|
||||||
ctx.clear_history()
|
|
||||||
print_success("Conversation history cleared.")
|
|
||||||
else:
|
|
||||||
print_warning(f"Unknown command: {user_input}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print_user_message(user_input)
|
|
||||||
await agent.run_turn(user_input)
|
|
||||||
logger.debug("turn_complete", message_count=ctx.message_count)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""Main entrypoint: load config, setup logging, launch interactive REPL."""
|
"""Main entrypoint: load config, preflight check, launch Textual TUI."""
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
# Setup logging first
|
# Setup logging first (will be reconfigured for TUI on mount)
|
||||||
setup_logging(
|
setup_logging(
|
||||||
log_file=args.log_file,
|
log_file=args.log_file,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
@@ -131,9 +70,17 @@ def main() -> None:
|
|||||||
print_error(f"Configuration error: {e}")
|
print_error(f"Configuration error: {e}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Override workspace root if directory argument provided
|
||||||
|
if args.directory:
|
||||||
|
target = Path(args.directory).resolve()
|
||||||
|
if not target.is_dir():
|
||||||
|
print_error(f"Not a directory: {target}")
|
||||||
|
sys.exit(1)
|
||||||
|
config.agent.workspace_root = target
|
||||||
|
|
||||||
logger.info("config_loaded", model=config.llm.model, endpoint=config.llm.endpoint)
|
logger.info("config_loaded", model=config.llm.model, endpoint=config.llm.endpoint)
|
||||||
|
|
||||||
# Print startup info
|
# Pre-TUI startup info (printed to console before Textual takes over)
|
||||||
print_banner()
|
print_banner()
|
||||||
print_info(f"Model: {config.llm.model}")
|
print_info(f"Model: {config.llm.model}")
|
||||||
print_info(f"Endpoint: {config.llm.endpoint}")
|
print_info(f"Endpoint: {config.llm.endpoint}")
|
||||||
@@ -154,12 +101,19 @@ def main() -> None:
|
|||||||
|
|
||||||
print_success("Ollama connected, model ready.")
|
print_success("Ollama connected, model ready.")
|
||||||
|
|
||||||
# Create session and start REPL
|
# Create session manager
|
||||||
ctx = SessionContext(config)
|
session_mgr = SessionManager(config.session, config.agent.workspace_root, config.llm.model)
|
||||||
logger.info("startup_complete")
|
|
||||||
|
|
||||||
print_info("Commands: /quit, /history, /clear")
|
# Clean up old session files
|
||||||
asyncio.run(_run_repl(ctx, config, logger))
|
cleaned = session_mgr.cleanup_old()
|
||||||
|
if cleaned > 0:
|
||||||
|
logger.info("old_sessions_cleaned", count=cleaned)
|
||||||
|
|
||||||
|
# Launch Textual TUI
|
||||||
|
from app.ui.app import SneakyCodeApp
|
||||||
|
|
||||||
|
app = SneakyCodeApp(config, session_mgr=session_mgr)
|
||||||
|
app.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,12 +1,39 @@
|
|||||||
"""Pydantic configuration models mapping to config/config.yaml."""
|
"""Pydantic configuration models mapping to config/config.yaml."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from enum import StrEnum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel, Field, model_validator
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMode(StrEnum):
|
||||||
|
"""Runtime agent mode controlling permission behavior."""
|
||||||
|
|
||||||
|
NORMAL = "normal"
|
||||||
|
PLAN = "plan"
|
||||||
|
AUTO = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProfile(BaseModel):
|
||||||
|
"""Per-model overrides applied when switching models."""
|
||||||
|
|
||||||
|
max_conversation_tokens: int | None = Field(
|
||||||
|
default=None, description="Token budget override for this model's context window"
|
||||||
|
)
|
||||||
|
thinking: bool | None = Field(
|
||||||
|
default=None, description="Override thinking mode for this model"
|
||||||
|
)
|
||||||
|
temperature: float | None = Field(
|
||||||
|
default=None, description="Override sampling temperature"
|
||||||
|
)
|
||||||
|
max_tokens: int | None = Field(
|
||||||
|
default=None, description="Override max response tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LLMConfig(BaseModel):
|
class LLMConfig(BaseModel):
|
||||||
"""LLM backend configuration."""
|
"""LLM backend configuration."""
|
||||||
|
|
||||||
@@ -16,6 +43,17 @@ class LLMConfig(BaseModel):
|
|||||||
temperature: float = Field(default=0.1, description="Sampling temperature")
|
temperature: float = Field(default=0.1, description="Sampling temperature")
|
||||||
max_tokens: int = Field(default=4096, description="Maximum tokens in LLM response")
|
max_tokens: int = Field(default=4096, description="Maximum tokens in LLM response")
|
||||||
timeout: int = Field(default=120, description="Request timeout in seconds")
|
timeout: int = Field(default=120, description="Request timeout in seconds")
|
||||||
|
max_retries: int = Field(default=3, description="Max retry attempts on transient errors")
|
||||||
|
retry_backoff_base: float = Field(default=1.0, description="Base seconds for exponential backoff")
|
||||||
|
retry_backoff_max: float = Field(default=30.0, description="Maximum backoff seconds")
|
||||||
|
thinking: bool = Field(
|
||||||
|
default=True,
|
||||||
|
description="Enable model thinking/reasoning mode (disable to reduce reasoning-only loops)",
|
||||||
|
)
|
||||||
|
extra_body: dict[str, Any] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Extra parameters merged into the API request body (model-specific)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(BaseModel):
|
class AgentConfig(BaseModel):
|
||||||
@@ -28,6 +66,12 @@ class AgentConfig(BaseModel):
|
|||||||
workspace_root: Path = Field(
|
workspace_root: Path = Field(
|
||||||
default=Path("."), description="Root directory for file operations"
|
default=Path("."), description="Root directory for file operations"
|
||||||
)
|
)
|
||||||
|
truncation_keep_recent: int = Field(
|
||||||
|
default=10, description="Number of recent messages to preserve during truncation"
|
||||||
|
)
|
||||||
|
truncation_threshold: float = Field(
|
||||||
|
default=0.85, description="Token budget fraction that triggers truncation"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PermissionsConfig(BaseModel):
|
class PermissionsConfig(BaseModel):
|
||||||
@@ -46,11 +90,19 @@ class ShellToolConfig(BaseModel):
|
|||||||
max_output_bytes: int = Field(default=65536, description="Max output capture size in bytes")
|
max_output_bytes: int = Field(default=65536, description="Max output capture size in bytes")
|
||||||
|
|
||||||
|
|
||||||
|
class FileCacheConfig(BaseModel):
|
||||||
|
"""File cache configuration."""
|
||||||
|
|
||||||
|
enabled: bool = Field(default=True, description="Enable file content caching")
|
||||||
|
max_entries: int = Field(default=128, description="Maximum cached file entries (LRU eviction)")
|
||||||
|
|
||||||
|
|
||||||
class FilesystemToolConfig(BaseModel):
|
class FilesystemToolConfig(BaseModel):
|
||||||
"""Filesystem tool limits."""
|
"""Filesystem tool limits."""
|
||||||
|
|
||||||
max_file_size_bytes: int = Field(default=1_048_576, description="Max file size for read/write")
|
max_file_size_bytes: int = Field(default=1_048_576, description="Max file size for read/write")
|
||||||
binary_detection: bool = Field(default=True, description="Detect and reject binary files")
|
binary_detection: bool = Field(default=True, description="Detect and reject binary files")
|
||||||
|
cache: FileCacheConfig = Field(default_factory=FileCacheConfig, description="File cache settings")
|
||||||
|
|
||||||
|
|
||||||
class ToolsConfig(BaseModel):
|
class ToolsConfig(BaseModel):
|
||||||
@@ -60,6 +112,19 @@ class ToolsConfig(BaseModel):
|
|||||||
filesystem: FilesystemToolConfig = Field(default_factory=FilesystemToolConfig)
|
filesystem: FilesystemToolConfig = Field(default_factory=FilesystemToolConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionConfig(BaseModel):
|
||||||
|
"""Session persistence configuration."""
|
||||||
|
|
||||||
|
session_dir: Path = Field(
|
||||||
|
default=Path(".sneakycode/sessions"), description="Directory for session files"
|
||||||
|
)
|
||||||
|
auto_save: bool = Field(default=True, description="Auto-save session after each turn")
|
||||||
|
max_session_age_hours: int = Field(
|
||||||
|
default=72, description="Max age in hours before session files are cleaned up"
|
||||||
|
)
|
||||||
|
offer_resume: bool = Field(default=True, description="Offer to resume previous sessions on startup")
|
||||||
|
|
||||||
|
|
||||||
class DisplayConfig(BaseModel):
|
class DisplayConfig(BaseModel):
|
||||||
"""Terminal display preferences."""
|
"""Terminal display preferences."""
|
||||||
|
|
||||||
@@ -68,6 +133,24 @@ class DisplayConfig(BaseModel):
|
|||||||
stream_output: bool = Field(default=True, description="Stream LLM output to terminal")
|
stream_output: bool = Field(default=True, description="Stream LLM output to terminal")
|
||||||
|
|
||||||
|
|
||||||
|
class SkillsConfig(BaseModel):
|
||||||
|
"""Skills system configuration."""
|
||||||
|
|
||||||
|
enabled: bool = Field(default=True, description="Enable skills system")
|
||||||
|
directories: list[Path] = Field(
|
||||||
|
default_factory=lambda: [Path(".sneakycode/skills")],
|
||||||
|
description="Directories to scan for skill markdown files",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DebugConfig(BaseModel):
|
||||||
|
"""Debug logging configuration."""
|
||||||
|
|
||||||
|
enabled: bool = Field(default=False, description="Enable debug logging")
|
||||||
|
log_dir: Path = Field(default=Path(".sneakycode/logs"), description="Debug log directory")
|
||||||
|
max_files: int = Field(default=10, description="Max debug log files to retain")
|
||||||
|
|
||||||
|
|
||||||
class AppConfig(BaseModel):
|
class AppConfig(BaseModel):
|
||||||
"""Top-level application configuration composing all sub-configs."""
|
"""Top-level application configuration composing all sub-configs."""
|
||||||
|
|
||||||
@@ -76,6 +159,13 @@ class AppConfig(BaseModel):
|
|||||||
permissions: PermissionsConfig = Field(default_factory=PermissionsConfig)
|
permissions: PermissionsConfig = Field(default_factory=PermissionsConfig)
|
||||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||||
display: DisplayConfig = Field(default_factory=DisplayConfig)
|
display: DisplayConfig = Field(default_factory=DisplayConfig)
|
||||||
|
session: SessionConfig = Field(default_factory=SessionConfig)
|
||||||
|
debug: DebugConfig = Field(default_factory=DebugConfig)
|
||||||
|
skills: SkillsConfig = Field(default_factory=SkillsConfig)
|
||||||
|
model_profiles: dict[str, ModelProfile] = Field(
|
||||||
|
default_factory=dict,
|
||||||
|
description="Per-model overrides keyed by model name prefix",
|
||||||
|
)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def resolve_workspace_root(self) -> "AppConfig":
|
def resolve_workspace_root(self) -> "AppConfig":
|
||||||
@@ -83,6 +173,39 @@ class AppConfig(BaseModel):
|
|||||||
self.agent.workspace_root = self.agent.workspace_root.resolve()
|
self.agent.workspace_root = self.agent.workspace_root.resolve()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def get_model_profile(self, model: str) -> ModelProfile | None:
|
||||||
|
"""Find the best matching model profile by prefix.
|
||||||
|
|
||||||
|
Matches the longest prefix first (e.g., "llama3.1" beats "llama3"
|
||||||
|
for model "llama3.1:latest"). Returns None if no profile matches.
|
||||||
|
"""
|
||||||
|
model_lower = model.lower().split(":")[0] # strip tag
|
||||||
|
best_match: str | None = None
|
||||||
|
for key in self.model_profiles:
|
||||||
|
key_lower = key.lower()
|
||||||
|
if model_lower == key_lower or model_lower.startswith(key_lower):
|
||||||
|
if best_match is None or len(key) > len(best_match):
|
||||||
|
best_match = key
|
||||||
|
return self.model_profiles.get(best_match) if best_match else None
|
||||||
|
|
||||||
|
def apply_model_profile(self, model: str) -> ModelProfile | None:
|
||||||
|
"""Apply the matching model profile overrides to the active config.
|
||||||
|
|
||||||
|
Returns the applied profile, or None if no profile matched.
|
||||||
|
"""
|
||||||
|
profile = self.get_model_profile(model)
|
||||||
|
if profile is None:
|
||||||
|
return None
|
||||||
|
if profile.max_conversation_tokens is not None:
|
||||||
|
self.agent.max_conversation_tokens = profile.max_conversation_tokens
|
||||||
|
if profile.thinking is not None:
|
||||||
|
self.llm.thinking = profile.thinking
|
||||||
|
if profile.temperature is not None:
|
||||||
|
self.llm.temperature = profile.temperature
|
||||||
|
if profile.max_tokens is not None:
|
||||||
|
self.llm.max_tokens = profile.max_tokens
|
||||||
|
return profile
|
||||||
|
|
||||||
|
|
||||||
# Default config file location relative to project root
|
# Default config file location relative to project root
|
||||||
_DEFAULT_CONFIG_PATH = Path("config/config.yaml")
|
_DEFAULT_CONFIG_PATH = Path("config/config.yaml")
|
||||||
|
|||||||
39
app/models/skill.py
Normal file
39
app/models/skill.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
"""Pydantic models for structured skill packages."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SkillConfigOverrides(BaseModel):
|
||||||
|
"""Scoped config overrides applied while a skill is active."""
|
||||||
|
|
||||||
|
temperature: float | None = Field(default=None, description="Override sampling temperature")
|
||||||
|
max_tokens: int | None = Field(default=None, description="Override max tokens")
|
||||||
|
tools_enable: list[str] | None = Field(
|
||||||
|
default=None, description="Whitelist — only these tools available when set"
|
||||||
|
)
|
||||||
|
tools_disable: list[str] | None = Field(
|
||||||
|
default=None, description="Blacklist — disable specific tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillManifest(BaseModel):
|
||||||
|
"""Parsed skill.yaml manifest for a skill package directory."""
|
||||||
|
|
||||||
|
name: str = Field(description="Unique skill identifier")
|
||||||
|
description: str = Field(description="Human-readable skill description")
|
||||||
|
version: str = Field(default="1.0", description="Skill version")
|
||||||
|
triggers: list[str] = Field(
|
||||||
|
default_factory=list, description="Slash commands that activate this skill"
|
||||||
|
)
|
||||||
|
config_overrides: SkillConfigOverrides = Field(
|
||||||
|
default_factory=SkillConfigOverrides, description="Scoped config overrides"
|
||||||
|
)
|
||||||
|
chain: list[str] = Field(
|
||||||
|
default_factory=list, description="Skill names to run first (dependencies)"
|
||||||
|
)
|
||||||
|
prompts: list[str] = Field(
|
||||||
|
default_factory=lambda: ["prompt.md"],
|
||||||
|
description="Markdown prompt files to load, in order",
|
||||||
|
)
|
||||||
77
app/services/debug_log.py
Normal file
77
app/services/debug_log.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Debug logger — writes detailed LLM interaction logs to JSONL files."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.models.message import Message
|
||||||
|
|
||||||
|
|
||||||
|
class DebugLogger:
|
||||||
|
"""Writes detailed LLM interaction logs to JSONL files for debugging."""
|
||||||
|
|
||||||
|
def __init__(self, log_dir: Path, max_files: int = 10) -> None:
|
||||||
|
self._log_dir = log_dir
|
||||||
|
self._log_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._max_files = max_files
|
||||||
|
self._file = self._log_dir / f"debug_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}.jsonl"
|
||||||
|
self._rotate()
|
||||||
|
|
||||||
|
def log_request(self, messages: list[Message], model: str) -> None:
|
||||||
|
"""Log outbound LLM request (message roles/lengths, model)."""
|
||||||
|
self._write({
|
||||||
|
"event": "llm_request",
|
||||||
|
"model": model,
|
||||||
|
"message_count": len(messages),
|
||||||
|
"messages": [
|
||||||
|
{"role": m.role, "content_len": len(m.content or "")}
|
||||||
|
for m in messages
|
||||||
|
],
|
||||||
|
})
|
||||||
|
|
||||||
|
def log_response(
|
||||||
|
self,
|
||||||
|
message: Message,
|
||||||
|
usage: Any | None,
|
||||||
|
elapsed_ms: float,
|
||||||
|
) -> None:
|
||||||
|
"""Log LLM response with timing and token counts."""
|
||||||
|
self._write({
|
||||||
|
"event": "llm_response",
|
||||||
|
"elapsed_ms": round(elapsed_ms, 1),
|
||||||
|
"content_len": len(message.content or ""),
|
||||||
|
"tool_call_count": len(message.tool_calls or []),
|
||||||
|
"tool_calls": [tc.function.name for tc in (message.tool_calls or [])],
|
||||||
|
"usage": usage.__dict__ if usage else None,
|
||||||
|
})
|
||||||
|
|
||||||
|
def log_tool_execution(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
result_status: str,
|
||||||
|
elapsed_ms: float,
|
||||||
|
) -> None:
|
||||||
|
"""Log tool execution with timing."""
|
||||||
|
self._write({
|
||||||
|
"event": "tool_execution",
|
||||||
|
"tool": tool_name,
|
||||||
|
"status": result_status,
|
||||||
|
"elapsed_ms": round(elapsed_ms, 1),
|
||||||
|
})
|
||||||
|
|
||||||
|
def _write(self, record: dict[str, Any]) -> None:
|
||||||
|
record["timestamp"] = datetime.now(UTC).isoformat()
|
||||||
|
with open(self._file, "a") as f:
|
||||||
|
f.write(json.dumps(record) + "\n")
|
||||||
|
|
||||||
|
def _rotate(self) -> None:
|
||||||
|
"""Remove old debug log files beyond max_files."""
|
||||||
|
files = sorted(
|
||||||
|
self._log_dir.glob("debug_*.jsonl"),
|
||||||
|
key=lambda p: p.stat().st_mtime,
|
||||||
|
)
|
||||||
|
while len(files) > self._max_files:
|
||||||
|
files.pop(0).unlink()
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
"""LLM client wrapper for Ollama / OpenAI-compatible endpoints."""
|
"""LLM client wrapper for Ollama / OpenAI-compatible endpoints."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import random
|
||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from typing import Any, Self
|
from typing import Any, Self
|
||||||
|
|
||||||
@@ -58,6 +60,25 @@ class LLMClient:
|
|||||||
timeout=httpx.Timeout(config.timeout, connect=10.0),
|
timeout=httpx.Timeout(config.timeout, connect=10.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def list_models(self) -> list[dict[str, str]]:
|
||||||
|
"""Query Ollama /api/tags for available models.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts with 'name' and 'size' keys.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LLMConnectionError: If the endpoint is unreachable.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await self._client.get("/api/tags")
|
||||||
|
data = response.json()
|
||||||
|
return [
|
||||||
|
{"name": m.get("name", ""), "size": str(m.get("size", ""))}
|
||||||
|
for m in data.get("models", [])
|
||||||
|
]
|
||||||
|
except (httpx.HTTPError, httpx.TimeoutException) as e:
|
||||||
|
raise LLMConnectionError(f"Failed to list models: {e}") from e
|
||||||
|
|
||||||
async def preflight_check(self) -> None:
|
async def preflight_check(self) -> None:
|
||||||
"""Verify the endpoint is reachable and the configured model is available.
|
"""Verify the endpoint is reachable and the configured model is available.
|
||||||
|
|
||||||
@@ -130,6 +151,15 @@ class LLMClient:
|
|||||||
if tools:
|
if tools:
|
||||||
payload["tools"] = tools
|
payload["tools"] = tools
|
||||||
|
|
||||||
|
# When thinking is disabled, inject chat_template_kwargs for backends
|
||||||
|
# that support it (Qwen 3.x thinking models).
|
||||||
|
if not self._config.thinking and self._config.model.lower().startswith(("qwen", "qwq")):
|
||||||
|
payload.setdefault("chat_template_kwargs", {})["enable_thinking"] = False
|
||||||
|
|
||||||
|
# Merge model-specific extra parameters (e.g., reasoning_effort)
|
||||||
|
if self._config.extra_body:
|
||||||
|
payload.update(self._config.extra_body)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._client.stream(
|
async with self._client.stream(
|
||||||
"POST", self._config.api_path, json=payload
|
"POST", self._config.api_path, json=payload
|
||||||
@@ -141,20 +171,32 @@ class LLMClient:
|
|||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
chunk_count = 0
|
||||||
async for line in response.aiter_lines():
|
async for line in response.aiter_lines():
|
||||||
if not line.startswith("data: "):
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
data = line[6:] # strip "data: " prefix
|
# SSE format: "data: {json}" or "data: [DONE]"
|
||||||
|
if line.startswith("data: "):
|
||||||
if data.strip() == "[DONE]":
|
data = line[6:]
|
||||||
return
|
if data.strip() == "[DONE]":
|
||||||
|
break
|
||||||
|
elif line.startswith("{"):
|
||||||
|
# Plain NDJSON fallback (some Ollama versions)
|
||||||
|
data = line
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield json.loads(data)
|
yield json.loads(data)
|
||||||
|
chunk_count += 1
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.warning("malformed_sse_chunk", data=data[:200])
|
logger.warning("malformed_sse_chunk", data=data[:200])
|
||||||
|
|
||||||
|
if chunk_count == 0:
|
||||||
|
logger.warning("empty_stream", model=self._config.model)
|
||||||
|
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
raise LLMConnectionError(f"Cannot connect to LLM endpoint: {e}") from e
|
raise LLMConnectionError(f"Cannot connect to LLM endpoint: {e}") from e
|
||||||
except httpx.TimeoutException as e:
|
except httpx.TimeoutException as e:
|
||||||
@@ -162,6 +204,61 @@ class LLMClient:
|
|||||||
except httpx.HTTPError as e:
|
except httpx.HTTPError as e:
|
||||||
raise LLMError(f"HTTP error communicating with LLM: {e}") from e
|
raise LLMError(f"HTTP error communicating with LLM: {e}") from e
|
||||||
|
|
||||||
|
async def stream_chat_with_retry(
|
||||||
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> AsyncIterator[dict]:
|
||||||
|
"""Stream chat with automatic retry on transient errors.
|
||||||
|
|
||||||
|
Retries on LLMConnectionError and LLMResponseError with status >= 500.
|
||||||
|
Does NOT retry on 4xx errors (client-side, not transient).
|
||||||
|
Uses exponential backoff with jitter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Conversation history to send to the model.
|
||||||
|
tools: Optional OpenAI function-calling tool schemas.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Parsed JSON dicts from each SSE data line.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
LLMConnectionError: After exhausting retries on connection failures.
|
||||||
|
LLMResponseError: After exhausting retries on server errors, or immediately on 4xx.
|
||||||
|
"""
|
||||||
|
max_retries = self._config.max_retries
|
||||||
|
last_exception: LLMError | None = None
|
||||||
|
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
async for chunk in self.stream_chat(messages, tools=tools):
|
||||||
|
yield chunk
|
||||||
|
return
|
||||||
|
except LLMConnectionError as e:
|
||||||
|
last_exception = e
|
||||||
|
except LLMResponseError as e:
|
||||||
|
if e.status_code is not None and e.status_code < 500:
|
||||||
|
raise
|
||||||
|
last_exception = e
|
||||||
|
except LLMStreamError as e:
|
||||||
|
last_exception = e
|
||||||
|
|
||||||
|
if attempt < max_retries:
|
||||||
|
backoff = min(
|
||||||
|
self._config.retry_backoff_base * (2 ** attempt) + random.uniform(0, 1),
|
||||||
|
self._config.retry_backoff_max,
|
||||||
|
)
|
||||||
|
logger.warning(
|
||||||
|
"llm_retry",
|
||||||
|
attempt=attempt + 1,
|
||||||
|
max_retries=max_retries,
|
||||||
|
backoff_seconds=round(backoff, 2),
|
||||||
|
error=str(last_exception),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(backoff)
|
||||||
|
|
||||||
|
raise last_exception # type: ignore[misc]
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
"""Close the underlying HTTP client."""
|
"""Close the underlying HTTP client."""
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|||||||
@@ -1,27 +1,79 @@
|
|||||||
"""Permission gating for tool execution."""
|
"""Permission gating for tool execution."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
from rich.prompt import Confirm
|
from app.models.config import AgentMode, PermissionsConfig, ToolsConfig
|
||||||
|
|
||||||
from app.models.config import PermissionsConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Type alias for the async prompt callback
|
||||||
|
PromptCallback = Callable[[str, str], Awaitable[bool]]
|
||||||
|
|
||||||
|
# Detect shell redirects that write to files (>, >>, heredocs)
|
||||||
|
_WRITE_REDIRECT_PATTERN = re.compile(r"(?:>\s*\S|>>|<<)")
|
||||||
|
|
||||||
|
|
||||||
class PermissionDenied(Exception):
|
class PermissionDenied(Exception):
|
||||||
"""Raised when a tool is denied execution by permissions policy."""
|
"""Raised when a tool is denied execution by permissions policy."""
|
||||||
|
|
||||||
|
|
||||||
class PermissionsService:
|
class PermissionsService:
|
||||||
"""Check whether a tool is allowed to execute based on config tiers."""
|
"""Check whether a tool is allowed to execute based on config tiers.
|
||||||
|
|
||||||
def __init__(self, config: PermissionsConfig) -> None:
|
In TUI mode, set a prompt callback via set_prompt_callback() that
|
||||||
|
shows a modal dialog. Without a callback, unlisted tools are denied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
READ_ONLY_TOOLS: frozenset[str] = frozenset({
|
||||||
|
"read_file", "list_dir", "grep_files", "find_files", "finish",
|
||||||
|
})
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PermissionsConfig,
|
||||||
|
tools_config: ToolsConfig | None = None,
|
||||||
|
) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self._tools_config = tools_config
|
||||||
|
self._prompt_callback: PromptCallback | None = None
|
||||||
|
self._mode: AgentMode = AgentMode.NORMAL
|
||||||
|
|
||||||
def check(self, tool_name: str, description: str = "") -> bool:
|
@property
|
||||||
|
def mode(self) -> AgentMode:
|
||||||
|
"""Current agent mode."""
|
||||||
|
return self._mode
|
||||||
|
|
||||||
|
@mode.setter
|
||||||
|
def mode(self, value: AgentMode) -> None:
|
||||||
|
self._mode = value
|
||||||
|
|
||||||
|
def set_prompt_callback(self, callback: PromptCallback) -> None:
|
||||||
|
"""Set the async callback used to prompt the user for permission.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Async function(tool_name, description) -> bool.
|
||||||
|
"""
|
||||||
|
self._prompt_callback = callback
|
||||||
|
|
||||||
|
async def check(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
description: str = "",
|
||||||
|
arguments: str = "",
|
||||||
|
) -> bool:
|
||||||
"""Check if a tool is permitted to run.
|
"""Check if a tool is permitted to run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool to check.
|
||||||
|
description: Human-readable description for the prompt.
|
||||||
|
arguments: Raw JSON arguments string (used for shell-aware checks).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if permitted, False if denied.
|
True if permitted, False if denied.
|
||||||
"""
|
"""
|
||||||
@@ -29,18 +81,73 @@ class PermissionsService:
|
|||||||
logger.info("Tool '%s' is in deny list — blocked", tool_name)
|
logger.info("Tool '%s' is in deny list — blocked", tool_name)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if self._mode == AgentMode.AUTO:
|
||||||
|
logger.debug("Tool '%s' auto-approved (AUTO mode)", tool_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if self._mode == AgentMode.PLAN:
|
||||||
|
if tool_name not in self.READ_ONLY_TOOLS:
|
||||||
|
logger.info("Tool '%s' blocked in Plan mode (read-only tools only)", tool_name)
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
if tool_name in self.config.auto_approve:
|
if tool_name in self.config.auto_approve:
|
||||||
logger.debug("Tool '%s' is auto-approved", tool_name)
|
logger.debug("Tool '%s' is auto-approved", tool_name)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Explicit prompt_user list or unlisted tools both trigger a prompt
|
# Shell-aware: check allowed/denied command lists
|
||||||
return self._prompt_user(tool_name, description)
|
if tool_name == "run_command" and self._tools_config is not None:
|
||||||
|
result = self._check_shell_command(arguments)
|
||||||
|
if result is not None:
|
||||||
|
return result
|
||||||
|
|
||||||
def _prompt_user(self, tool_name: str, description: str) -> bool:
|
# Prompt user via callback (TUI modal, etc.)
|
||||||
"""Prompt the user for approval via the terminal."""
|
if self._prompt_callback is not None:
|
||||||
prompt_text = f"Allow tool [bold]{tool_name}[/bold]"
|
return await self._prompt_callback(tool_name, description)
|
||||||
if description:
|
|
||||||
prompt_text += f" — {description}"
|
|
||||||
prompt_text += "?"
|
|
||||||
|
|
||||||
return Confirm.ask(prompt_text, default=False)
|
# No callback set — deny by default (safe fallback)
|
||||||
|
logger.warning("Tool '%s' requires approval but no prompt callback set — denied", tool_name)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _check_shell_command(self, arguments: str) -> bool | None:
|
||||||
|
"""Check shell command against allowed/denied lists.
|
||||||
|
|
||||||
|
Returns True (allow), False (deny), or None (fall through to prompt).
|
||||||
|
"""
|
||||||
|
shell_config = self._tools_config.shell # type: ignore[union-attr]
|
||||||
|
|
||||||
|
try:
|
||||||
|
cmd = json.loads(arguments).get("command", "")
|
||||||
|
except (json.JSONDecodeError, AttributeError):
|
||||||
|
return None # can't parse, fall through to prompt
|
||||||
|
|
||||||
|
try:
|
||||||
|
base_cmd = shlex.split(cmd)[0]
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Denied commands: prefix match on full command string
|
||||||
|
for denied in shell_config.denied_commands:
|
||||||
|
if cmd.startswith(denied):
|
||||||
|
logger.info("Shell command '%s' matches denied prefix '%s'", cmd, denied)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Detect shell redirects that write to files — require approval
|
||||||
|
if _WRITE_REDIRECT_PATTERN.search(cmd):
|
||||||
|
logger.info("Shell command '%s' contains file-write redirect — requiring approval", cmd)
|
||||||
|
return None # fall through to user prompt
|
||||||
|
|
||||||
|
# Allowed commands: base executable match
|
||||||
|
if shell_config.allowed_commands:
|
||||||
|
if base_cmd in shell_config.allowed_commands:
|
||||||
|
logger.debug(
|
||||||
|
"Shell command '%s' auto-approved (base '%s' in allowed list)",
|
||||||
|
cmd,
|
||||||
|
base_cmd,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
# Base command NOT in allowed list — fall through to prompt
|
||||||
|
return None
|
||||||
|
|
||||||
|
# No allowed list configured — fall through to prompt
|
||||||
|
return None
|
||||||
|
|||||||
152
app/services/session.py
Normal file
152
app/services/session.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
"""Session persistence — auto-save and restore conversation state."""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.config import SessionConfig
|
||||||
|
from app.utils.logging import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.agent.context import SessionContext
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionData(BaseModel):
|
||||||
|
"""Serialized session state for persistence."""
|
||||||
|
|
||||||
|
version: int = Field(default=1, description="Schema version for forward compatibility")
|
||||||
|
session_id: str = Field(description="Unique session identifier")
|
||||||
|
created_at: str = Field(description="ISO timestamp of session creation")
|
||||||
|
updated_at: str = Field(description="ISO timestamp of last update")
|
||||||
|
model: str = Field(description="LLM model name used in session")
|
||||||
|
workspace_root: str = Field(description="Workspace root path")
|
||||||
|
messages: list[dict] = Field(default_factory=list, description="Serialized messages")
|
||||||
|
token_usage: dict = Field(default_factory=dict, description="Cumulative token usage")
|
||||||
|
|
||||||
|
|
||||||
|
class SessionManager:
|
||||||
|
"""Manages session file I/O: save, load, restore, and cleanup.
|
||||||
|
|
||||||
|
Session files are keyed by a hash of the workspace root path so that
|
||||||
|
each project directory has its own session history.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: SessionConfig, workspace_root: Path, model: str) -> None:
|
||||||
|
"""Initialize session manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Session configuration.
|
||||||
|
workspace_root: Absolute path to workspace root.
|
||||||
|
model: LLM model name for session metadata.
|
||||||
|
"""
|
||||||
|
self._config = config
|
||||||
|
self._workspace_root = workspace_root
|
||||||
|
self._model = model
|
||||||
|
self._workspace_hash = hashlib.sha256(str(workspace_root).encode()).hexdigest()[:12]
|
||||||
|
self._session_dir = workspace_root / config.session_dir
|
||||||
|
self._session_id = f"{self._workspace_hash}_{datetime.now(UTC).strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
|
||||||
|
def update_model(self, model: str) -> None:
|
||||||
|
"""Update the model name for session metadata."""
|
||||||
|
self._model = model
|
||||||
|
|
||||||
|
def save(self, ctx: "SessionContext") -> Path:
|
||||||
|
"""Save session state to a JSON file via atomic write.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx: Session context to persist.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the saved session file.
|
||||||
|
"""
|
||||||
|
self._session_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
serialized = ctx.to_serializable()
|
||||||
|
data = SessionData(
|
||||||
|
session_id=self._session_id,
|
||||||
|
created_at=ctx.start_time.isoformat(),
|
||||||
|
updated_at=datetime.now(UTC).isoformat(),
|
||||||
|
model=self._model,
|
||||||
|
workspace_root=str(self._workspace_root),
|
||||||
|
messages=serialized["messages"],
|
||||||
|
token_usage=serialized["token_usage"],
|
||||||
|
)
|
||||||
|
|
||||||
|
file_path = self._session_dir / f"{self._session_id}.json"
|
||||||
|
tmp_path = file_path.with_suffix(".tmp")
|
||||||
|
|
||||||
|
tmp_path.write_text(data.model_dump_json(indent=2), encoding="utf-8")
|
||||||
|
tmp_path.rename(file_path)
|
||||||
|
|
||||||
|
logger.debug("session_saved", path=str(file_path))
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
def load_latest(self) -> SessionData | None:
|
||||||
|
"""Find and load the newest session file for this workspace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SessionData if a valid session is found, None otherwise.
|
||||||
|
"""
|
||||||
|
if not self._session_dir.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
session_files = sorted(
|
||||||
|
self._session_dir.glob(f"{self._workspace_hash}_*.json"),
|
||||||
|
key=lambda p: p.stat().st_mtime,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
for path in session_files:
|
||||||
|
try:
|
||||||
|
raw = json.loads(path.read_text(encoding="utf-8"))
|
||||||
|
return SessionData(**raw)
|
||||||
|
except (json.JSONDecodeError, ValueError, OSError) as e:
|
||||||
|
logger.warning("session_load_error", path=str(path), error=str(e))
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def restore(self, data: SessionData, ctx: "SessionContext") -> None:
|
||||||
|
"""Replay session data into a SessionContext.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Saved session data to restore.
|
||||||
|
ctx: Session context to populate.
|
||||||
|
"""
|
||||||
|
ctx.restore_from({
|
||||||
|
"messages": data.messages,
|
||||||
|
"token_usage": data.token_usage,
|
||||||
|
})
|
||||||
|
# Preserve the original session ID for continuity
|
||||||
|
self._session_id = data.session_id
|
||||||
|
logger.info("session_restored", session_id=data.session_id, messages=len(data.messages))
|
||||||
|
|
||||||
|
def cleanup_old(self) -> int:
|
||||||
|
"""Delete session files older than max_session_age_hours.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of files deleted.
|
||||||
|
"""
|
||||||
|
if not self._session_dir.exists():
|
||||||
|
return 0
|
||||||
|
|
||||||
|
cutoff = datetime.now(UTC).timestamp() - (self._config.max_session_age_hours * 3600)
|
||||||
|
deleted = 0
|
||||||
|
|
||||||
|
for path in self._session_dir.glob("*.json"):
|
||||||
|
try:
|
||||||
|
if path.stat().st_mtime < cutoff:
|
||||||
|
path.unlink()
|
||||||
|
deleted += 1
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if deleted > 0:
|
||||||
|
logger.info("sessions_cleaned", deleted=deleted)
|
||||||
|
return deleted
|
||||||
234
app/services/skill_runner.py
Normal file
234
app/services/skill_runner.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""SkillRunner — orchestrates skill activation, chaining, config scoping, and deactivation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from app.agent.context import SessionContext
|
||||||
|
from app.models.config import AppConfig
|
||||||
|
from app.models.skill import SkillManifest
|
||||||
|
from app.services.skills import Skill, SkillsManager
|
||||||
|
from app.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillChainError(Exception):
|
||||||
|
"""Raised when skill chain resolution fails (e.g., cycle detected)."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _SkillSnapshot:
|
||||||
|
"""Captured state before skill activation, for restoration on deactivate."""
|
||||||
|
|
||||||
|
temperature: float
|
||||||
|
max_tokens: int
|
||||||
|
disabled_tools: set[str] = field(default_factory=set)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillRunner:
|
||||||
|
"""Manages skill lifecycle: activation, chaining, config overrides, deactivation.
|
||||||
|
|
||||||
|
Only one skill can be active at a time. Activating a new skill while one
|
||||||
|
is active will first deactivate the current skill.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
skills_manager: SkillsManager,
|
||||||
|
config: AppConfig,
|
||||||
|
ctx: SessionContext,
|
||||||
|
registry: ToolRegistry,
|
||||||
|
) -> None:
|
||||||
|
self._skills = skills_manager
|
||||||
|
self._config = config
|
||||||
|
self._ctx = ctx
|
||||||
|
self._registry = registry
|
||||||
|
self._active_skill: Skill | None = None
|
||||||
|
self._snapshot: _SkillSnapshot | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_active(self) -> bool:
|
||||||
|
"""Whether a skill is currently active."""
|
||||||
|
return self._active_skill is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_skill_name(self) -> str | None:
|
||||||
|
"""Name of the currently active skill, or None."""
|
||||||
|
return self._active_skill.name if self._active_skill else None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_skill(self) -> Skill | None:
|
||||||
|
"""The currently active skill, or None."""
|
||||||
|
return self._active_skill
|
||||||
|
|
||||||
|
def activate(self, skill_name: str) -> str | None:
|
||||||
|
"""Activate a skill by name.
|
||||||
|
|
||||||
|
Resolves chain dependencies (depth-first), applies config overrides,
|
||||||
|
injects prompt content into conversation context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Name of the skill to activate.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The concatenated prompt content injected, or None on failure.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SkillChainError: If chain resolution detects a cycle.
|
||||||
|
"""
|
||||||
|
skill = self._skills.get_skill(skill_name)
|
||||||
|
if skill is None:
|
||||||
|
logger.warning("Cannot activate unknown skill: %s", skill_name)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Deactivate current skill if one is active
|
||||||
|
if self._active_skill is not None:
|
||||||
|
self.deactivate()
|
||||||
|
|
||||||
|
# Resolve chain dependencies
|
||||||
|
chain = self._resolve_chain(skill, set())
|
||||||
|
|
||||||
|
# Snapshot current config for restoration
|
||||||
|
self._snapshot = _SkillSnapshot(
|
||||||
|
temperature=self._config.llm.temperature,
|
||||||
|
max_tokens=self._config.llm.max_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect and inject chain skill prompts first
|
||||||
|
all_prompts: list[str] = []
|
||||||
|
for chained_skill in chain:
|
||||||
|
content = self._skills.load_skill(chained_skill.name)
|
||||||
|
if content:
|
||||||
|
all_prompts.append(f"[Chained skill: {chained_skill.name}]\n{content}")
|
||||||
|
|
||||||
|
# Load the target skill's prompts
|
||||||
|
content = self._skills.load_skill(skill.name)
|
||||||
|
if content:
|
||||||
|
all_prompts.append(content)
|
||||||
|
|
||||||
|
# Apply config overrides from the target skill
|
||||||
|
if skill.manifest:
|
||||||
|
self._apply_overrides(skill.manifest)
|
||||||
|
|
||||||
|
# Inject prompts into context
|
||||||
|
full_prompt = "\n\n".join(all_prompts) if all_prompts else None
|
||||||
|
if full_prompt:
|
||||||
|
self._ctx.add_message(
|
||||||
|
"system",
|
||||||
|
f"[Skill activated: {skill.name}]\n{full_prompt}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._active_skill = skill
|
||||||
|
logger.info("Skill activated: %s", skill.name)
|
||||||
|
return full_prompt
|
||||||
|
|
||||||
|
def activate_by_trigger(self, trigger: str) -> str | None:
|
||||||
|
"""Activate a skill by its /command trigger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trigger: The trigger string (with or without leading /).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The concatenated prompt content, or None if no skill matches.
|
||||||
|
"""
|
||||||
|
skill = self._skills.get_skill_by_trigger(trigger)
|
||||||
|
if skill is None:
|
||||||
|
return None
|
||||||
|
return self.activate(skill.name)
|
||||||
|
|
||||||
|
def deactivate(self, summary: str | None = None) -> None:
|
||||||
|
"""Deactivate the current skill, restoring config and tool state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
summary: Optional summary message to inject into context.
|
||||||
|
"""
|
||||||
|
if self._active_skill is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
skill_name = self._active_skill.name
|
||||||
|
|
||||||
|
# Restore config
|
||||||
|
if self._snapshot is not None:
|
||||||
|
self._config.llm.temperature = self._snapshot.temperature
|
||||||
|
self._config.llm.max_tokens = self._snapshot.max_tokens
|
||||||
|
self._registry.restore_filter(self._snapshot.disabled_tools)
|
||||||
|
self._snapshot = None
|
||||||
|
|
||||||
|
if summary:
|
||||||
|
self._ctx.add_message(
|
||||||
|
"system",
|
||||||
|
f"[Skill completed: {skill_name}] {summary}",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._active_skill = None
|
||||||
|
logger.info("Skill deactivated: %s", skill_name)
|
||||||
|
|
||||||
|
def _resolve_chain(
|
||||||
|
self, skill: Skill, in_progress: set[str], completed: set[str] | None = None,
|
||||||
|
) -> list[Skill]:
|
||||||
|
"""Depth-first resolution of skill chain dependencies.
|
||||||
|
|
||||||
|
Uses separate in_progress (current path) and completed sets to correctly
|
||||||
|
handle diamond dependencies without false cycle detection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill: The skill whose chain to resolve.
|
||||||
|
in_progress: Skills on the current recursion path (for cycle detection).
|
||||||
|
completed: Skills already fully resolved (skip duplicates).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Ordered list of chained skills to activate before the target.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SkillChainError: If a cycle is detected.
|
||||||
|
"""
|
||||||
|
if completed is None:
|
||||||
|
completed = set()
|
||||||
|
|
||||||
|
if skill.manifest is None or not skill.manifest.chain:
|
||||||
|
return []
|
||||||
|
|
||||||
|
result: list[Skill] = []
|
||||||
|
for dep_name in skill.manifest.chain:
|
||||||
|
if dep_name in completed:
|
||||||
|
continue # Already resolved via another branch (diamond dep)
|
||||||
|
|
||||||
|
if dep_name in in_progress:
|
||||||
|
raise SkillChainError(
|
||||||
|
f"Cycle detected in skill chain: {dep_name} already in progress "
|
||||||
|
f"(path: {' -> '.join(in_progress)} -> {dep_name})"
|
||||||
|
)
|
||||||
|
|
||||||
|
dep_skill = self._skills.get_skill(dep_name)
|
||||||
|
if dep_skill is None:
|
||||||
|
logger.warning("Chained skill not found: %s (required by %s)", dep_name, skill.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
in_progress.add(dep_name)
|
||||||
|
result.extend(self._resolve_chain(dep_skill, in_progress, completed))
|
||||||
|
in_progress.discard(dep_name)
|
||||||
|
completed.add(dep_name)
|
||||||
|
result.append(dep_skill)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _apply_overrides(self, manifest: SkillManifest) -> None:
|
||||||
|
"""Apply config overrides from a skill manifest."""
|
||||||
|
overrides = manifest.config_overrides
|
||||||
|
|
||||||
|
if overrides.temperature is not None:
|
||||||
|
self._config.llm.temperature = overrides.temperature
|
||||||
|
|
||||||
|
if overrides.max_tokens is not None:
|
||||||
|
self._config.llm.max_tokens = overrides.max_tokens
|
||||||
|
|
||||||
|
if overrides.tools_enable is not None or overrides.tools_disable is not None:
|
||||||
|
previous = self._registry.apply_filter(
|
||||||
|
enable=overrides.tools_enable,
|
||||||
|
disable=overrides.tools_disable,
|
||||||
|
)
|
||||||
|
# Store for restoration
|
||||||
|
if self._snapshot:
|
||||||
|
self._snapshot.disabled_tools = previous
|
||||||
162
app/services/skills.py
Normal file
162
app/services/skills.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""Skills manager — scans for and loads skill packages and legacy markdown files."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from app.models.config import SkillsConfig
|
||||||
|
from app.models.skill import SkillManifest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Skill(BaseModel):
|
||||||
|
"""Metadata for a discovered skill (package or legacy flat file)."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
path: Path
|
||||||
|
manifest: SkillManifest | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SkillsManager:
|
||||||
|
"""Discovers, indexes, and loads skill files from configured directories.
|
||||||
|
|
||||||
|
Supports both:
|
||||||
|
- Directory-based packages (contain skill.yaml + prompt .md files)
|
||||||
|
- Legacy flat .md files (backwards compatible)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: SkillsConfig, workspace_root: Path) -> None:
|
||||||
|
self._config = config
|
||||||
|
self._workspace = workspace_root
|
||||||
|
self._skills: dict[str, Skill] = {}
|
||||||
|
self._trigger_map: dict[str, str] = {} # trigger -> skill name
|
||||||
|
self._scan()
|
||||||
|
|
||||||
|
def _scan(self) -> None:
|
||||||
|
"""Scan configured directories for skill packages and legacy .md files."""
|
||||||
|
for skill_dir in self._config.directories:
|
||||||
|
resolved = (self._workspace / skill_dir) if not skill_dir.is_absolute() else skill_dir
|
||||||
|
if not resolved.is_dir():
|
||||||
|
logger.debug("Skills directory does not exist: %s", resolved)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for entry in sorted(resolved.iterdir()):
|
||||||
|
if entry.is_dir():
|
||||||
|
self._scan_package(entry)
|
||||||
|
elif entry.suffix == ".md":
|
||||||
|
self._scan_legacy(entry)
|
||||||
|
|
||||||
|
def _scan_package(self, pkg_dir: Path) -> None:
|
||||||
|
"""Scan a directory-based skill package containing skill.yaml."""
|
||||||
|
manifest_path = pkg_dir / "skill.yaml"
|
||||||
|
if not manifest_path.exists():
|
||||||
|
logger.debug("Skipping directory without skill.yaml: %s", pkg_dir)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw = yaml.safe_load(manifest_path.read_text())
|
||||||
|
manifest = SkillManifest(**raw)
|
||||||
|
except (yaml.YAMLError, ValidationError, TypeError) as e:
|
||||||
|
logger.warning("Failed to parse skill manifest %s: %s", manifest_path, e)
|
||||||
|
return
|
||||||
|
|
||||||
|
skill = Skill(
|
||||||
|
name=manifest.name,
|
||||||
|
description=manifest.description,
|
||||||
|
path=pkg_dir,
|
||||||
|
manifest=manifest,
|
||||||
|
)
|
||||||
|
self._skills[manifest.name] = skill
|
||||||
|
|
||||||
|
# Register triggers
|
||||||
|
for trigger in manifest.triggers:
|
||||||
|
normalized = trigger.lstrip("/").lower()
|
||||||
|
self._trigger_map[normalized] = manifest.name
|
||||||
|
|
||||||
|
logger.debug("Discovered skill package: %s (%s)", manifest.name, manifest.description)
|
||||||
|
|
||||||
|
def _scan_legacy(self, md_path: Path) -> None:
|
||||||
|
"""Scan a legacy flat .md skill file."""
|
||||||
|
name = md_path.stem
|
||||||
|
desc = self._extract_description(md_path)
|
||||||
|
self._skills[name] = Skill(name=name, description=desc, path=md_path)
|
||||||
|
logger.debug("Discovered legacy skill: %s (%s)", name, desc)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_description(path: Path) -> str:
|
||||||
|
"""Extract the first non-blank, non-heading line as the description."""
|
||||||
|
for line in path.read_text().splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if stripped and not stripped.startswith("#"):
|
||||||
|
return stripped
|
||||||
|
return "(no description)"
|
||||||
|
|
||||||
|
def list_skills(self) -> list[Skill]:
|
||||||
|
"""Return all discovered skills."""
|
||||||
|
return list(self._skills.values())
|
||||||
|
|
||||||
|
def get_skill(self, name: str) -> Skill | None:
|
||||||
|
"""Look up a skill by name."""
|
||||||
|
return self._skills.get(name)
|
||||||
|
|
||||||
|
def get_skill_by_trigger(self, trigger: str) -> Skill | None:
|
||||||
|
"""Look up a skill by /command trigger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trigger: The trigger string (with or without leading /).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The matching Skill, or None.
|
||||||
|
"""
|
||||||
|
normalized = trigger.lstrip("/").lower()
|
||||||
|
skill_name = self._trigger_map.get(normalized)
|
||||||
|
if skill_name:
|
||||||
|
return self._skills.get(skill_name)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def load_skill(self, name: str) -> str | None:
|
||||||
|
"""Load the full content of a skill by name.
|
||||||
|
|
||||||
|
For package skills, concatenates all prompt .md files.
|
||||||
|
For legacy skills, returns the .md file content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Concatenated prompt content, or None if not found.
|
||||||
|
"""
|
||||||
|
skill = self._skills.get(name)
|
||||||
|
if skill is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if skill.manifest is not None:
|
||||||
|
# Package skill: load prompt files
|
||||||
|
parts: list[str] = []
|
||||||
|
for prompt_file in skill.manifest.prompts:
|
||||||
|
prompt_path = skill.path / prompt_file
|
||||||
|
if prompt_path.exists():
|
||||||
|
parts.append(prompt_path.read_text())
|
||||||
|
else:
|
||||||
|
logger.warning("Prompt file not found: %s", prompt_path)
|
||||||
|
return "\n\n".join(parts) if parts else None
|
||||||
|
else:
|
||||||
|
# Legacy flat file
|
||||||
|
return skill.path.read_text()
|
||||||
|
|
||||||
|
def get_system_prompt_snippet(self) -> str:
|
||||||
|
"""Generate a snippet for the system prompt listing available skills."""
|
||||||
|
if not self._skills:
|
||||||
|
return ""
|
||||||
|
lines = ["\nAvailable skills (invoke with /skill-name):"]
|
||||||
|
for s in self._skills.values():
|
||||||
|
if s.manifest and s.manifest.triggers:
|
||||||
|
trigger_str = ", ".join(s.manifest.triggers)
|
||||||
|
lines.append(f" - {trigger_str}: {s.description}")
|
||||||
|
else:
|
||||||
|
lines.append(f" - /{s.name}: {s.description}")
|
||||||
|
lines.append("To use a skill's full instructions, call the load_skill tool.")
|
||||||
|
return "\n".join(lines)
|
||||||
@@ -1,41 +1,56 @@
|
|||||||
"""Streaming response handler — accumulates SSE chunks into a complete Message."""
|
"""Streaming response handler — accumulates SSE chunks into a complete Message."""
|
||||||
|
|
||||||
from collections.abc import AsyncIterator
|
import time
|
||||||
|
from collections.abc import AsyncIterator, Callable
|
||||||
from rich.live import Live
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
from rich.panel import Panel
|
|
||||||
|
|
||||||
from app.models.config import DisplayConfig
|
from app.models.config import DisplayConfig
|
||||||
from app.models.message import Message
|
from app.models.message import Message
|
||||||
from app.models.tool_call import ToolCall, ToolCallFunction
|
from app.models.tool_call import ToolCall, ToolCallFunction
|
||||||
from app.utils.logging import console, get_logger
|
from app.utils.logging import get_logger
|
||||||
from app.utils.token_counter import TokenUsage
|
from app.utils.token_counter import TokenUsage
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Minimum interval between content update callbacks (seconds)
|
||||||
|
_UPDATE_THROTTLE_INTERVAL = 0.1
|
||||||
|
|
||||||
|
|
||||||
class StreamHandler:
|
class StreamHandler:
|
||||||
"""Processes an SSE chunk stream into a Rich live display and final Message.
|
"""Processes an SSE chunk stream and produces a complete assistant Message.
|
||||||
|
|
||||||
Accumulates content deltas and tool call fragments, renders a live Markdown
|
Accumulates content deltas and tool call fragments. Notifies the UI via
|
||||||
panel during streaming, and produces a complete assistant Message on finish.
|
optional callbacks during streaming.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, display_config: DisplayConfig) -> None:
|
def __init__(self, display_config: DisplayConfig) -> None:
|
||||||
"""Initialize the stream handler.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
display_config: Display preferences (streaming toggle, etc.).
|
|
||||||
"""
|
|
||||||
self._display_config = display_config
|
self._display_config = display_config
|
||||||
self._accumulated_content: str = ""
|
self._accumulated_content: str = ""
|
||||||
self._accumulated_reasoning: str = ""
|
self._accumulated_reasoning: str = ""
|
||||||
self._tool_calls: dict[int, dict[str, str]] = {}
|
self._tool_calls: dict[int, dict[str, str]] = {}
|
||||||
self._usage: TokenUsage | None = None
|
self._usage: TokenUsage | None = None
|
||||||
|
self._on_content: Callable[[str], None] | None = None
|
||||||
|
self._on_thinking: Callable[[], None] | None = None
|
||||||
|
self._on_done: Callable[[], None] | None = None
|
||||||
|
|
||||||
|
def set_callbacks(
|
||||||
|
self,
|
||||||
|
on_content: Callable[[str], None] | None = None,
|
||||||
|
on_thinking: Callable[[], None] | None = None,
|
||||||
|
on_done: Callable[[], None] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set UI callbacks for streaming updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
on_content: Called with accumulated content string (throttled to ~100ms).
|
||||||
|
on_thinking: Called once when first reasoning token arrives.
|
||||||
|
on_done: Called when streaming is complete.
|
||||||
|
"""
|
||||||
|
self._on_content = on_content
|
||||||
|
self._on_thinking = on_thinking
|
||||||
|
self._on_done = on_done
|
||||||
|
|
||||||
async def process_stream(self, chunk_iter: AsyncIterator[dict]) -> Message:
|
async def process_stream(self, chunk_iter: AsyncIterator[dict]) -> Message:
|
||||||
"""Consume a chunk iterator, rendering live output and returning the final Message.
|
"""Consume a chunk iterator and return the final Message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunk_iter: Async iterator of parsed SSE chunk dicts.
|
chunk_iter: Async iterator of parsed SSE chunk dicts.
|
||||||
@@ -43,28 +58,54 @@ class StreamHandler:
|
|||||||
Returns:
|
Returns:
|
||||||
Complete assistant Message with accumulated content and tool calls.
|
Complete assistant Message with accumulated content and tool calls.
|
||||||
"""
|
"""
|
||||||
with Live(console=console, refresh_per_second=8) as live:
|
thinking_notified = False
|
||||||
async for chunk in chunk_iter:
|
last_update_time = 0.0
|
||||||
self._process_chunk(chunk)
|
chunk_count = 0
|
||||||
|
|
||||||
# Show reasoning while waiting for content
|
async for chunk in chunk_iter:
|
||||||
display_text = self._accumulated_content
|
chunk_count += 1
|
||||||
if not display_text and self._accumulated_reasoning:
|
self._process_chunk(chunk)
|
||||||
display_text = "*thinking...*"
|
|
||||||
|
|
||||||
if display_text and self._display_config.stream_output:
|
if not self._display_config.stream_output:
|
||||||
# Render inside the same Assistant panel used for final output
|
continue
|
||||||
# so the live display and final frame are visually consistent
|
|
||||||
live.update(
|
# Notify thinking once
|
||||||
Panel(
|
if (
|
||||||
Markdown(display_text),
|
not thinking_notified
|
||||||
title="Assistant",
|
and not self._accumulated_content
|
||||||
border_style="green",
|
and self._accumulated_reasoning
|
||||||
expand=True,
|
and self._on_thinking is not None
|
||||||
)
|
):
|
||||||
)
|
self._on_thinking()
|
||||||
|
thinking_notified = True
|
||||||
|
|
||||||
|
# Throttled content updates
|
||||||
|
if self._accumulated_content and self._on_content is not None:
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - last_update_time >= _UPDATE_THROTTLE_INTERVAL:
|
||||||
|
self._on_content(self._accumulated_content)
|
||||||
|
last_update_time = now
|
||||||
|
|
||||||
|
# Final content update (ensures last chunk is shown)
|
||||||
|
if (
|
||||||
|
self._display_config.stream_output
|
||||||
|
and self._accumulated_content
|
||||||
|
and self._on_content is not None
|
||||||
|
):
|
||||||
|
self._on_content(self._accumulated_content)
|
||||||
|
|
||||||
|
if self._on_done is not None:
|
||||||
|
self._on_done()
|
||||||
|
|
||||||
tool_calls = self._build_tool_calls() or None
|
tool_calls = self._build_tool_calls() or None
|
||||||
|
|
||||||
|
if chunk_count > 0 and not self._accumulated_content and not tool_calls:
|
||||||
|
logger.debug(
|
||||||
|
"stream_empty_result",
|
||||||
|
chunks_received=chunk_count,
|
||||||
|
had_reasoning=bool(self._accumulated_reasoning),
|
||||||
|
)
|
||||||
|
|
||||||
return Message(
|
return Message(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=self._accumulated_content or None,
|
content=self._accumulated_content or None,
|
||||||
@@ -72,12 +113,7 @@ class StreamHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _process_chunk(self, chunk: dict) -> None:
|
def _process_chunk(self, chunk: dict) -> None:
|
||||||
"""Extract content, tool calls, and usage from a single SSE chunk.
|
"""Extract content, tool calls, and usage from a single SSE chunk."""
|
||||||
|
|
||||||
Args:
|
|
||||||
chunk: Parsed JSON dict from one SSE data line.
|
|
||||||
"""
|
|
||||||
# Content delta
|
|
||||||
choices = chunk.get("choices", [])
|
choices = chunk.get("choices", [])
|
||||||
if choices:
|
if choices:
|
||||||
delta = choices[0].get("delta", {})
|
delta = choices[0].get("delta", {})
|
||||||
@@ -86,12 +122,10 @@ class StreamHandler:
|
|||||||
if content_piece:
|
if content_piece:
|
||||||
self._accumulated_content += content_piece
|
self._accumulated_content += content_piece
|
||||||
|
|
||||||
# Reasoning tokens (e.g. qwen3.5 thinking mode)
|
|
||||||
reasoning_piece = delta.get("reasoning")
|
reasoning_piece = delta.get("reasoning")
|
||||||
if reasoning_piece:
|
if reasoning_piece:
|
||||||
self._accumulated_reasoning += reasoning_piece
|
self._accumulated_reasoning += reasoning_piece
|
||||||
|
|
||||||
# Tool call deltas (accumulated by index)
|
|
||||||
for tc_delta in delta.get("tool_calls", []):
|
for tc_delta in delta.get("tool_calls", []):
|
||||||
idx = tc_delta.get("index", 0)
|
idx = tc_delta.get("index", 0)
|
||||||
if idx not in self._tool_calls:
|
if idx not in self._tool_calls:
|
||||||
@@ -109,7 +143,6 @@ class StreamHandler:
|
|||||||
if func.get("arguments"):
|
if func.get("arguments"):
|
||||||
entry["arguments"] += func["arguments"]
|
entry["arguments"] += func["arguments"]
|
||||||
|
|
||||||
# Token usage (typically in the final chunk)
|
|
||||||
usage_data = chunk.get("usage")
|
usage_data = chunk.get("usage")
|
||||||
if usage_data:
|
if usage_data:
|
||||||
self._usage = TokenUsage(
|
self._usage = TokenUsage(
|
||||||
@@ -119,11 +152,7 @@ class StreamHandler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _build_tool_calls(self) -> list[ToolCall]:
|
def _build_tool_calls(self) -> list[ToolCall]:
|
||||||
"""Convert accumulated tool call fragments into sorted ToolCall list.
|
"""Convert accumulated tool call fragments into sorted ToolCall list."""
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ToolCall objects sorted by stream index.
|
|
||||||
"""
|
|
||||||
if not self._tool_calls:
|
if not self._tool_calls:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -142,6 +171,17 @@ class StreamHandler:
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_partial_message(self) -> Message | None:
|
||||||
|
"""Return whatever content/tool_calls have been accumulated so far."""
|
||||||
|
tool_calls = self._build_tool_calls() or None
|
||||||
|
if not self._accumulated_content and not tool_calls:
|
||||||
|
return None
|
||||||
|
return Message(
|
||||||
|
role="assistant",
|
||||||
|
content=self._accumulated_content or None,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def usage(self) -> TokenUsage | None:
|
def usage(self) -> TokenUsage | None:
|
||||||
"""Token usage reported by the API, if available."""
|
"""Token usage reported by the API, if available."""
|
||||||
@@ -153,7 +193,7 @@ class StreamHandler:
|
|||||||
return bool(self._accumulated_reasoning) and not self._accumulated_content and not self._tool_calls
|
return bool(self._accumulated_reasoning) and not self._accumulated_content and not self._tool_calls
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""Clear all accumulators for the next turn."""
|
"""Clear accumulators for the next LLM call, preserving UI callbacks."""
|
||||||
self._accumulated_content = ""
|
self._accumulated_content = ""
|
||||||
self._accumulated_reasoning = ""
|
self._accumulated_reasoning = ""
|
||||||
self._tool_calls.clear()
|
self._tool_calls.clear()
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
"""Edit tools: str_replace and patch_apply."""
|
"""Edit tools: str_replace and patch_apply."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import subprocess
|
import subprocess
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.config import AppConfig
|
||||||
from app.models.tool_call import ToolResult, ToolResultStatus
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
from app.tools.base import BaseTool
|
from app.tools.base import BaseTool
|
||||||
|
from app.utils.file_cache import FileCache, cached_read_file
|
||||||
from app.utils.file_helpers import (
|
from app.utils.file_helpers import (
|
||||||
FileSizeError,
|
FileSizeError,
|
||||||
PathSecurityError,
|
PathSecurityError,
|
||||||
@@ -37,6 +41,12 @@ class StrReplaceTool(BaseTool):
|
|||||||
)
|
)
|
||||||
params_model = StrReplaceParams
|
params_model = StrReplaceParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, workspace_root: Path, config: AppConfig, file_cache: FileCache | None = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__(workspace_root, config)
|
||||||
|
self._file_cache = file_cache
|
||||||
|
|
||||||
def execute(
|
def execute(
|
||||||
self, *, tool_call_id: str, file_path: str, old_str: str, new_str: str, **kwargs: Any
|
self, *, tool_call_id: str, file_path: str, old_str: str, new_str: str, **kwargs: Any
|
||||||
) -> ToolResult:
|
) -> ToolResult:
|
||||||
@@ -44,11 +54,12 @@ class StrReplaceTool(BaseTool):
|
|||||||
|
|
||||||
# Read the file
|
# Read the file
|
||||||
try:
|
try:
|
||||||
content = safe_read_file(
|
content = cached_read_file(
|
||||||
file_path,
|
file_path,
|
||||||
self.workspace_root,
|
self.workspace_root,
|
||||||
max_size_bytes=fs_config.max_file_size_bytes,
|
max_size_bytes=fs_config.max_file_size_bytes,
|
||||||
check_binary=fs_config.binary_detection,
|
check_binary=fs_config.binary_detection,
|
||||||
|
cache=self._file_cache,
|
||||||
)
|
)
|
||||||
except PathSecurityError as exc:
|
except PathSecurityError as exc:
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
@@ -117,8 +128,14 @@ class StrReplaceTool(BaseTool):
|
|||||||
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
||||||
rel_path = safe_path.relative_to(self.workspace_root)
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
except (PathSecurityError, ValueError):
|
except (PathSecurityError, ValueError):
|
||||||
|
safe_path = None
|
||||||
rel_path = Path(file_path)
|
rel_path = Path(file_path)
|
||||||
|
|
||||||
|
# Pre-warm cache with the new content (we already have it in memory).
|
||||||
|
if self._file_cache is not None and safe_path is not None:
|
||||||
|
self._file_cache.invalidate(safe_path)
|
||||||
|
self._file_cache.put(safe_path, new_content)
|
||||||
|
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
tool_name=self.name,
|
tool_name=self.name,
|
||||||
@@ -144,6 +161,12 @@ class PatchApplyTool(BaseTool):
|
|||||||
)
|
)
|
||||||
params_model = PatchApplyParams
|
params_model = PatchApplyParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, workspace_root: Path, config: AppConfig, file_cache: FileCache | None = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__(workspace_root, config)
|
||||||
|
self._file_cache = file_cache
|
||||||
|
|
||||||
def execute(self, *, tool_call_id: str, file_path: str, patch: str, **kwargs: Any) -> ToolResult:
|
def execute(self, *, tool_call_id: str, file_path: str, patch: str, **kwargs: Any) -> ToolResult:
|
||||||
try:
|
try:
|
||||||
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
||||||
@@ -195,6 +218,9 @@ class PatchApplyTool(BaseTool):
|
|||||||
error=f"Patch failed (exit {result.returncode}): {result.stderr or result.stdout}",
|
error=f"Patch failed (exit {result.returncode}): {result.stderr or result.stdout}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._file_cache is not None:
|
||||||
|
self._file_cache.invalidate(safe_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rel_path = safe_path.relative_to(self.workspace_root)
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
"""Filesystem tools: read_file, list_dir, write_file, make_dir, delete_file."""
|
"""Filesystem tools: read_file, list_dir, write_file, make_dir, delete_file."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.config import AppConfig
|
||||||
from app.models.tool_call import ToolResult, ToolResultStatus
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
from app.tools.base import BaseTool
|
from app.tools.base import BaseTool
|
||||||
|
from app.utils.file_cache import FileCache, cached_read_file
|
||||||
from app.utils.file_helpers import (
|
from app.utils.file_helpers import (
|
||||||
BinaryFileError,
|
BinaryFileError,
|
||||||
FileSizeError,
|
FileSizeError,
|
||||||
@@ -23,6 +27,12 @@ class ReadFileParams(BaseModel):
|
|||||||
file_path: str = Field(description="Path to the file to read (relative to workspace root)")
|
file_path: str = Field(description="Path to the file to read (relative to workspace root)")
|
||||||
|
|
||||||
|
|
||||||
|
class ReadManyFilesParams(BaseModel):
|
||||||
|
"""Parameters for the read_many_files tool."""
|
||||||
|
|
||||||
|
file_paths: list[str] = Field(description="List of file paths to read (relative to workspace root)")
|
||||||
|
|
||||||
|
|
||||||
class ReadFileTool(BaseTool):
|
class ReadFileTool(BaseTool):
|
||||||
"""Read the contents of a file within the workspace."""
|
"""Read the contents of a file within the workspace."""
|
||||||
|
|
||||||
@@ -30,14 +40,22 @@ class ReadFileTool(BaseTool):
|
|||||||
description = "Read the full contents of a text file. Returns the file content as a string."
|
description = "Read the full contents of a text file. Returns the file content as a string."
|
||||||
params_model = ReadFileParams
|
params_model = ReadFileParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, workspace_root: Path, config: AppConfig, file_cache: FileCache | None = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__(workspace_root, config)
|
||||||
|
self._file_cache = file_cache
|
||||||
|
|
||||||
def execute(self, *, tool_call_id: str, file_path: str, **kwargs: Any) -> ToolResult:
|
def execute(self, *, tool_call_id: str, file_path: str, **kwargs: Any) -> ToolResult:
|
||||||
fs_config = self.config.tools.filesystem
|
fs_config = self.config.tools.filesystem
|
||||||
|
hits_before = self._file_cache.stats.hits if self._file_cache else 0
|
||||||
try:
|
try:
|
||||||
content = safe_read_file(
|
content = cached_read_file(
|
||||||
file_path,
|
file_path,
|
||||||
self.workspace_root,
|
self.workspace_root,
|
||||||
max_size_bytes=fs_config.max_file_size_bytes,
|
max_size_bytes=fs_config.max_file_size_bytes,
|
||||||
check_binary=fs_config.binary_detection,
|
check_binary=fs_config.binary_detection,
|
||||||
|
cache=self._file_cache,
|
||||||
)
|
)
|
||||||
except PathSecurityError as exc:
|
except PathSecurityError as exc:
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
@@ -47,11 +65,12 @@ class ReadFileTool(BaseTool):
|
|||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
except FileNotFoundError as exc:
|
except FileNotFoundError as exc:
|
||||||
|
filename = Path(file_path).name
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
tool_name=self.name,
|
tool_name=self.name,
|
||||||
status=ToolResultStatus.ERROR,
|
status=ToolResultStatus.ERROR,
|
||||||
error=str(exc),
|
error=f"{exc}. Use find_files to locate it, e.g. find_files(pattern=\"{filename}\")",
|
||||||
)
|
)
|
||||||
except FileSizeError as exc:
|
except FileSizeError as exc:
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
@@ -68,6 +87,23 @@ class ReadFileTool(BaseTool):
|
|||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# On cache hit the file is unchanged — its content is already in
|
||||||
|
# conversation context from the earlier read, so avoid resending it.
|
||||||
|
was_cache_hit = (
|
||||||
|
self._file_cache is not None
|
||||||
|
and self._file_cache.stats.hits > hits_before
|
||||||
|
)
|
||||||
|
if was_cache_hit:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=(
|
||||||
|
f"[Cached] {file_path} is unchanged since last read "
|
||||||
|
f"({len(content):,} chars). Content is already in conversation context."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
return ToolResult(
|
return ToolResult(
|
||||||
tool_call_id=tool_call_id,
|
tool_call_id=tool_call_id,
|
||||||
tool_name=self.name,
|
tool_name=self.name,
|
||||||
@@ -76,6 +112,76 @@ class ReadFileTool(BaseTool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadManyFilesTool(BaseTool):
|
||||||
|
"""Read contents of multiple files at once."""
|
||||||
|
|
||||||
|
name = "read_many_files"
|
||||||
|
description = (
|
||||||
|
"Read contents of multiple files at once. Returns each file's content "
|
||||||
|
"prefixed with its path header."
|
||||||
|
)
|
||||||
|
params_model = ReadManyFilesParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, workspace_root: Path, config: AppConfig, file_cache: FileCache | None = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__(workspace_root, config)
|
||||||
|
self._file_cache = file_cache
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, file_paths: list[str], **kwargs: Any) -> ToolResult:
|
||||||
|
if not file_paths:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error="file_paths list is empty",
|
||||||
|
)
|
||||||
|
|
||||||
|
fs_config = self.config.tools.filesystem
|
||||||
|
sections: list[str] = []
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
for fp in file_paths:
|
||||||
|
hits_before = self._file_cache.stats.hits if self._file_cache else 0
|
||||||
|
try:
|
||||||
|
content = cached_read_file(
|
||||||
|
fp,
|
||||||
|
self.workspace_root,
|
||||||
|
max_size_bytes=fs_config.max_file_size_bytes,
|
||||||
|
check_binary=fs_config.binary_detection,
|
||||||
|
cache=self._file_cache,
|
||||||
|
)
|
||||||
|
was_hit = (
|
||||||
|
self._file_cache is not None
|
||||||
|
and self._file_cache.stats.hits > hits_before
|
||||||
|
)
|
||||||
|
if was_hit:
|
||||||
|
sections.append(
|
||||||
|
f"=== {fp} ===\n[Cached] Unchanged since last read "
|
||||||
|
f"({len(content):,} chars). Already in conversation context."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sections.append(f"=== {fp} ===\n{content}")
|
||||||
|
success_count += 1
|
||||||
|
except (PathSecurityError, FileNotFoundError, FileSizeError, BinaryFileError) as exc:
|
||||||
|
sections.append(f"=== {fp} ===\n[ERROR] {exc}")
|
||||||
|
|
||||||
|
if success_count == 0:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error="All files failed to read:\n" + "\n".join(sections),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output="\n".join(sections),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ListDirParams(BaseModel):
|
class ListDirParams(BaseModel):
|
||||||
"""Parameters for the list_dir tool."""
|
"""Parameters for the list_dir tool."""
|
||||||
|
|
||||||
@@ -167,6 +273,12 @@ class WriteFileTool(BaseTool):
|
|||||||
)
|
)
|
||||||
params_model = WriteFileParams
|
params_model = WriteFileParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, workspace_root: Path, config: AppConfig, file_cache: FileCache | None = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__(workspace_root, config)
|
||||||
|
self._file_cache = file_cache
|
||||||
|
|
||||||
def execute(self, *, tool_call_id: str, file_path: str, content: str, **kwargs: Any) -> ToolResult:
|
def execute(self, *, tool_call_id: str, file_path: str, content: str, **kwargs: Any) -> ToolResult:
|
||||||
fs_config = self.config.tools.filesystem
|
fs_config = self.config.tools.filesystem
|
||||||
try:
|
try:
|
||||||
@@ -191,6 +303,9 @@ class WriteFileTool(BaseTool):
|
|||||||
error=str(exc),
|
error=str(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._file_cache is not None:
|
||||||
|
self._file_cache.invalidate(safe_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rel_path = safe_path.relative_to(self.workspace_root)
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -272,6 +387,12 @@ class DeleteFileTool(BaseTool):
|
|||||||
description = "Delete a single file. Does not delete directories."
|
description = "Delete a single file. Does not delete directories."
|
||||||
params_model = DeleteFileParams
|
params_model = DeleteFileParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, workspace_root: Path, config: AppConfig, file_cache: FileCache | None = None
|
||||||
|
) -> None:
|
||||||
|
super().__init__(workspace_root, config)
|
||||||
|
self._file_cache = file_cache
|
||||||
|
|
||||||
def execute(self, *, tool_call_id: str, file_path: str, **kwargs: Any) -> ToolResult:
|
def execute(self, *, tool_call_id: str, file_path: str, **kwargs: Any) -> ToolResult:
|
||||||
try:
|
try:
|
||||||
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
safe_path = resolve_safe_path(file_path, self.workspace_root)
|
||||||
@@ -309,6 +430,9 @@ class DeleteFileTool(BaseTool):
|
|||||||
error=f"Failed to delete file: {exc}",
|
error=f"Failed to delete file: {exc}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._file_cache is not None:
|
||||||
|
self._file_cache.invalidate(safe_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rel_path = safe_path.relative_to(self.workspace_root)
|
rel_path = safe_path.relative_to(self.workspace_root)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|||||||
@@ -1,11 +1,17 @@
|
|||||||
"""Tool registration and schema export."""
|
"""Tool registration and schema export."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from app.models.config import AppConfig
|
from app.models.config import AppConfig
|
||||||
from app.tools.base import BaseTool
|
from app.tools.base import BaseTool
|
||||||
|
from app.utils.file_cache import FileCache
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.skills import SkillsManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,6 +21,7 @@ class ToolRegistry:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._tools: dict[str, BaseTool] = {}
|
self._tools: dict[str, BaseTool] = {}
|
||||||
|
self._disabled: set[str] = set()
|
||||||
|
|
||||||
def register(self, tool: BaseTool) -> None:
|
def register(self, tool: BaseTool) -> None:
|
||||||
"""Register a tool instance. Raises ValueError on duplicate name."""
|
"""Register a tool instance. Raises ValueError on duplicate name."""
|
||||||
@@ -24,22 +31,78 @@ class ToolRegistry:
|
|||||||
logger.debug("Registered tool: %s", tool.name)
|
logger.debug("Registered tool: %s", tool.name)
|
||||||
|
|
||||||
def get(self, name: str) -> BaseTool | None:
|
def get(self, name: str) -> BaseTool | None:
|
||||||
"""Look up a tool by name."""
|
"""Look up a tool by name. Returns None if disabled or not found."""
|
||||||
|
if name in self._disabled:
|
||||||
|
return None
|
||||||
return self._tools.get(name)
|
return self._tools.get(name)
|
||||||
|
|
||||||
def get_all(self) -> dict[str, BaseTool]:
|
def get_all(self) -> dict[str, BaseTool]:
|
||||||
"""Return all registered tools."""
|
"""Return all registered tools (excluding disabled)."""
|
||||||
return dict(self._tools)
|
return {k: v for k, v in self._tools.items() if k not in self._disabled}
|
||||||
|
|
||||||
def get_openai_tools_schema(self) -> list[dict[str, Any]]:
|
def get_openai_tools_schema(self) -> list[dict[str, Any]]:
|
||||||
"""Return OpenAI function-calling schemas for all registered tools."""
|
"""Return OpenAI function-calling schemas for all active tools."""
|
||||||
return [tool.get_openai_schema() for tool in self._tools.values()]
|
return [
|
||||||
|
tool.get_openai_schema()
|
||||||
|
for tool in self._tools.values()
|
||||||
|
if tool.name not in self._disabled
|
||||||
|
]
|
||||||
|
|
||||||
|
def apply_filter(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
enable: list[str] | None = None,
|
||||||
|
disable: list[str] | None = None,
|
||||||
|
) -> set[str]:
|
||||||
|
"""Apply a tool filter, returning the previous disabled set for restoration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable: If set, only these tools (plus always-on tools) are available.
|
||||||
|
disable: Specific tools to disable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The previous disabled set (snapshot for restore).
|
||||||
|
"""
|
||||||
|
previous = set(self._disabled)
|
||||||
|
|
||||||
|
if enable is not None:
|
||||||
|
# Whitelist mode: disable everything not in the enable list
|
||||||
|
self._disabled = {name for name in self._tools if name not in enable}
|
||||||
|
elif disable is not None:
|
||||||
|
# Blacklist mode: add to existing disabled set (preserves global disables)
|
||||||
|
self._disabled = set(self._disabled) | set(disable)
|
||||||
|
else:
|
||||||
|
self._disabled = set()
|
||||||
|
|
||||||
|
return previous
|
||||||
|
|
||||||
|
def restore_filter(self, previous: set[str]) -> None:
|
||||||
|
"""Restore a previous filter state."""
|
||||||
|
self._disabled = previous
|
||||||
|
|
||||||
|
def all_tool_names(self) -> list[str]:
|
||||||
|
"""Return all registered tool names (including disabled)."""
|
||||||
|
return list(self._tools.keys())
|
||||||
|
|
||||||
|
|
||||||
def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegistry:
|
def create_default_registry(
|
||||||
"""Create a ToolRegistry populated with all built-in tools."""
|
workspace_root: Path,
|
||||||
|
config: AppConfig,
|
||||||
|
skills_manager: SkillsManager | None = None,
|
||||||
|
skill_runner: object | None = None,
|
||||||
|
file_cache: FileCache | None = None,
|
||||||
|
) -> ToolRegistry:
|
||||||
|
"""Create a ToolRegistry populated with all built-in tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_root: Workspace root path.
|
||||||
|
config: Application configuration.
|
||||||
|
skills_manager: Optional skills manager for skill tools.
|
||||||
|
skill_runner: Optional SkillRunner for package skill activation.
|
||||||
|
file_cache: Optional file cache shared across file-reading tools.
|
||||||
|
"""
|
||||||
# Read tools
|
# Read tools
|
||||||
from app.tools.filesystem import ListDirTool, ReadFileTool
|
from app.tools.filesystem import ListDirTool, ReadFileTool, ReadManyFilesTool
|
||||||
|
|
||||||
# Write tools
|
# Write tools
|
||||||
from app.tools.filesystem import DeleteFileTool, MakeDirTool, WriteFileTool
|
from app.tools.filesystem import DeleteFileTool, MakeDirTool, WriteFileTool
|
||||||
@@ -59,7 +122,8 @@ def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegi
|
|||||||
registry = ToolRegistry()
|
registry = ToolRegistry()
|
||||||
|
|
||||||
# Read
|
# Read
|
||||||
registry.register(ReadFileTool(workspace_root, config))
|
registry.register(ReadFileTool(workspace_root, config, file_cache=file_cache))
|
||||||
|
registry.register(ReadManyFilesTool(workspace_root, config, file_cache=file_cache))
|
||||||
registry.register(ListDirTool(workspace_root, config))
|
registry.register(ListDirTool(workspace_root, config))
|
||||||
|
|
||||||
# Search
|
# Search
|
||||||
@@ -67,13 +131,13 @@ def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegi
|
|||||||
registry.register(FindFilesTool(workspace_root, config))
|
registry.register(FindFilesTool(workspace_root, config))
|
||||||
|
|
||||||
# Write
|
# Write
|
||||||
registry.register(WriteFileTool(workspace_root, config))
|
registry.register(WriteFileTool(workspace_root, config, file_cache=file_cache))
|
||||||
registry.register(MakeDirTool(workspace_root, config))
|
registry.register(MakeDirTool(workspace_root, config))
|
||||||
registry.register(DeleteFileTool(workspace_root, config))
|
registry.register(DeleteFileTool(workspace_root, config, file_cache=file_cache))
|
||||||
|
|
||||||
# Edit
|
# Edit
|
||||||
registry.register(StrReplaceTool(workspace_root, config))
|
registry.register(StrReplaceTool(workspace_root, config, file_cache=file_cache))
|
||||||
registry.register(PatchApplyTool(workspace_root, config))
|
registry.register(PatchApplyTool(workspace_root, config, file_cache=file_cache))
|
||||||
|
|
||||||
# Shell
|
# Shell
|
||||||
registry.register(RunCommandTool(workspace_root, config))
|
registry.register(RunCommandTool(workspace_root, config))
|
||||||
@@ -81,4 +145,13 @@ def create_default_registry(workspace_root: Path, config: AppConfig) -> ToolRegi
|
|||||||
# Control flow
|
# Control flow
|
||||||
registry.register(FinishTool(workspace_root, config))
|
registry.register(FinishTool(workspace_root, config))
|
||||||
|
|
||||||
|
# Skills (conditional)
|
||||||
|
if skills_manager is not None:
|
||||||
|
from app.services.skill_runner import SkillRunner as SkillRunnerType
|
||||||
|
from app.tools.skills import FinishSkillTool, LoadSkillTool
|
||||||
|
|
||||||
|
runner = skill_runner if isinstance(skill_runner, SkillRunnerType) else None
|
||||||
|
registry.register(LoadSkillTool(workspace_root, config, skills_manager, runner))
|
||||||
|
registry.register(FinishSkillTool(workspace_root, config, runner))
|
||||||
|
|
||||||
return registry
|
return registry
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Shell tool: run_command."""
|
"""Shell tool: run_command."""
|
||||||
|
|
||||||
|
import re
|
||||||
import shlex
|
import shlex
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -11,6 +12,9 @@ from app.tools.base import BaseTool
|
|||||||
|
|
||||||
_DEFAULT_TIMEOUT = 30
|
_DEFAULT_TIMEOUT = 30
|
||||||
|
|
||||||
|
# Detect shell redirects that write to files (>, >>, heredocs)
|
||||||
|
_WRITE_REDIRECT_PATTERN = re.compile(r"(?:>\s*\S|>>|<<)")
|
||||||
|
|
||||||
|
|
||||||
class RunCommandParams(BaseModel):
|
class RunCommandParams(BaseModel):
|
||||||
"""Parameters for the run_command tool."""
|
"""Parameters for the run_command tool."""
|
||||||
@@ -43,6 +47,18 @@ class RunCommandTool(BaseTool):
|
|||||||
error=f"Command denied: matches blocked prefix '{denied}'",
|
error=f"Command denied: matches blocked prefix '{denied}'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Defense-in-depth: flag file-write redirects in tool result
|
||||||
|
if _WRITE_REDIRECT_PATTERN.search(command):
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=(
|
||||||
|
f"Command contains file-write redirect (>, >>, or <<) "
|
||||||
|
f"which bypasses file-write permissions. Use write_file instead."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Allow check: first token must be in allowed_commands
|
# Allow check: first token must be in allowed_commands
|
||||||
try:
|
try:
|
||||||
tokens = shlex.split(command)
|
tokens = shlex.split(command)
|
||||||
|
|||||||
158
app/tools/skills.py
Normal file
158
app/tools/skills.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Skill tools — load and finish skills during agent operation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, ClassVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.models.config import AppConfig
|
||||||
|
from app.models.tool_call import ToolResult, ToolResultStatus
|
||||||
|
from app.tools.base import BaseTool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.services.skill_runner import SkillRunner
|
||||||
|
from app.services.skills import SkillsManager
|
||||||
|
|
||||||
|
|
||||||
|
class LoadSkillParams(BaseModel):
|
||||||
|
"""Parameters for the load_skill tool."""
|
||||||
|
|
||||||
|
name: str = Field(description="Name of the skill to load")
|
||||||
|
|
||||||
|
|
||||||
|
class LoadSkillTool(BaseTool):
|
||||||
|
"""Load a skill's full instructions by name.
|
||||||
|
|
||||||
|
Use when a skill is relevant to the current task.
|
||||||
|
For package skills, this activates the full skill lifecycle
|
||||||
|
(config overrides, chaining, prompt injection).
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: ClassVar[str] = "load_skill"
|
||||||
|
description: ClassVar[str] = (
|
||||||
|
"Load a skill's full instructions by name. "
|
||||||
|
"Use when a skill is relevant to the current task."
|
||||||
|
)
|
||||||
|
params_model: ClassVar[type[BaseModel]] = LoadSkillParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workspace_root: Path,
|
||||||
|
config: AppConfig,
|
||||||
|
skills_manager: SkillsManager,
|
||||||
|
skill_runner: SkillRunner | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(workspace_root, config)
|
||||||
|
self._skills = skills_manager
|
||||||
|
self._runner = skill_runner
|
||||||
|
|
||||||
|
def set_skill_runner(self, runner: SkillRunner) -> None:
|
||||||
|
"""Late-bind the SkillRunner (avoids circular init dependencies)."""
|
||||||
|
self._runner = runner
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, **kwargs: Any) -> ToolResult:
|
||||||
|
skill_name: str = kwargs["name"]
|
||||||
|
|
||||||
|
# Check if skill exists
|
||||||
|
skill = self._skills.get_skill(skill_name)
|
||||||
|
if skill is None:
|
||||||
|
available = [s.name for s in self._skills.list_skills()]
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Unknown skill '{skill_name}'. Available: {available}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# For package skills with a runner, use full activation flow
|
||||||
|
if skill.manifest is not None and self._runner is not None:
|
||||||
|
content = self._runner.activate(skill_name)
|
||||||
|
if content is None:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Failed to activate skill '{skill_name}'",
|
||||||
|
)
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=f"Skill '{skill_name}' activated.\n\n{content}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Legacy skill: just load content
|
||||||
|
content = self._skills.load_skill(skill_name)
|
||||||
|
if content is None:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error=f"Failed to load skill '{skill_name}'",
|
||||||
|
)
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FinishSkillParams(BaseModel):
|
||||||
|
"""Parameters for the finish_skill tool."""
|
||||||
|
|
||||||
|
summary: str = Field(
|
||||||
|
default="Skill complete.",
|
||||||
|
description="Brief summary of what was accomplished during the skill",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FinishSkillTool(BaseTool):
|
||||||
|
"""Signal that the active skill is complete and should be deactivated.
|
||||||
|
|
||||||
|
Restores config overrides and tool availability to pre-skill state.
|
||||||
|
The agent loop continues after this (unlike the finish tool).
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: ClassVar[str] = "finish_skill"
|
||||||
|
description: ClassVar[str] = (
|
||||||
|
"Call this when the active skill's task is complete. "
|
||||||
|
"Deactivates the skill and restores normal config. "
|
||||||
|
"The conversation continues after this."
|
||||||
|
)
|
||||||
|
params_model: ClassVar[type[BaseModel]] = FinishSkillParams
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workspace_root: Path,
|
||||||
|
config: AppConfig,
|
||||||
|
skill_runner: SkillRunner | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(workspace_root, config)
|
||||||
|
self._runner = skill_runner
|
||||||
|
|
||||||
|
def set_skill_runner(self, runner: SkillRunner) -> None:
|
||||||
|
"""Late-bind the SkillRunner (avoids circular init dependencies)."""
|
||||||
|
self._runner = runner
|
||||||
|
|
||||||
|
def execute(self, *, tool_call_id: str, **kwargs: Any) -> ToolResult:
|
||||||
|
summary: str = kwargs.get("summary", "Skill complete.")
|
||||||
|
|
||||||
|
if self._runner is None or not self._runner.is_active:
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.ERROR,
|
||||||
|
error="No skill is currently active.",
|
||||||
|
)
|
||||||
|
|
||||||
|
skill_name = self._runner.active_skill_name
|
||||||
|
self._runner.deactivate(summary=summary)
|
||||||
|
return ToolResult(
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
tool_name=self.name,
|
||||||
|
status=ToolResultStatus.SUCCESS,
|
||||||
|
output=f"Skill '{skill_name}' completed: {summary}",
|
||||||
|
)
|
||||||
0
app/ui/__init__.py
Normal file
0
app/ui/__init__.py
Normal file
471
app/ui/app.py
Normal file
471
app/ui/app.py
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
"""SneakyCode Textual TUI application."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.text import Text
|
||||||
|
from textual.app import App, ComposeResult
|
||||||
|
from textual.binding import Binding
|
||||||
|
from textual.widgets import Input, RichLog
|
||||||
|
from textual import work
|
||||||
|
|
||||||
|
from app.agent.context import SessionContext
|
||||||
|
from app.agent.loop import AgentLoop
|
||||||
|
from app.models.config import AgentMode, AppConfig
|
||||||
|
from app.services.llm import LLMClient
|
||||||
|
from app.services.permissions import PermissionsService
|
||||||
|
from app.services.session import SessionManager
|
||||||
|
from app.services.streaming import StreamHandler
|
||||||
|
from app.tools.registry import create_default_registry
|
||||||
|
from app.ui.widgets import (
|
||||||
|
HeaderPanel,
|
||||||
|
HistoryInput,
|
||||||
|
PermissionModal,
|
||||||
|
SessionResumeModal,
|
||||||
|
StatusBar,
|
||||||
|
StreamingStatic,
|
||||||
|
)
|
||||||
|
from app.utils.display import DisplayAdapter
|
||||||
|
from app.utils.logging import get_logger, setup_logging_for_tui
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from textual.worker import Worker
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SneakyCodeApp(App):
|
||||||
|
"""Main TUI application for SneakyCode."""
|
||||||
|
|
||||||
|
TITLE = "SneakyCode"
|
||||||
|
CSS_PATH = "styles.tcss"
|
||||||
|
|
||||||
|
BINDINGS = [
|
||||||
|
Binding("ctrl+c", "cancel_or_quit", "Cancel/Quit", show=False),
|
||||||
|
Binding("ctrl+p", "cycle_mode", "Cycle Mode"),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, config: AppConfig, session_mgr: SessionManager | None = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._config = config
|
||||||
|
self._session_mgr = session_mgr
|
||||||
|
self._ctx: SessionContext | None = None
|
||||||
|
self._agent: AgentLoop | None = None
|
||||||
|
self._client: LLMClient | None = None
|
||||||
|
self._tool_registry = None
|
||||||
|
self._permissions: PermissionsService | None = None
|
||||||
|
self._debug_logger = None
|
||||||
|
self._skills_manager = None
|
||||||
|
self._skill_runner = None
|
||||||
|
self._current_worker: Worker | None = None
|
||||||
|
self._cancel_count = 0
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
yield HeaderPanel(model_name=self._config.llm.model)
|
||||||
|
yield RichLog(id="chat-log", highlight=True, markup=True)
|
||||||
|
yield StreamingStatic("", id="streaming")
|
||||||
|
yield StatusBar()
|
||||||
|
yield HistoryInput(placeholder="Enter your prompt...")
|
||||||
|
|
||||||
|
async def on_mount(self) -> None:
|
||||||
|
"""Initialize agent components after the app is mounted."""
|
||||||
|
setup_logging_for_tui()
|
||||||
|
|
||||||
|
# Apply model profile for the initial model before creating context
|
||||||
|
self._config.apply_model_profile(self._config.llm.model)
|
||||||
|
|
||||||
|
self._ctx = SessionContext(self._config)
|
||||||
|
|
||||||
|
# Create long-lived agent dependencies (reused across turns)
|
||||||
|
self._client = LLMClient(self._config.llm)
|
||||||
|
await self._client.__aenter__()
|
||||||
|
self._permissions = PermissionsService(self._config.permissions, self._config.tools)
|
||||||
|
|
||||||
|
# Create debug logger if enabled
|
||||||
|
if self._config.debug.enabled:
|
||||||
|
from app.services.debug_log import DebugLogger
|
||||||
|
|
||||||
|
log_dir = self._config.agent.workspace_root / self._config.debug.log_dir
|
||||||
|
self._debug_logger = DebugLogger(log_dir, self._config.debug.max_files)
|
||||||
|
|
||||||
|
# Initialize skills system
|
||||||
|
if self._config.skills.enabled:
|
||||||
|
from app.services.skills import SkillsManager
|
||||||
|
|
||||||
|
self._skills_manager = SkillsManager(
|
||||||
|
self._config.skills, self._config.agent.workspace_root
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create file cache if enabled
|
||||||
|
self._file_cache = None
|
||||||
|
fs_cache_cfg = self._config.tools.filesystem.cache
|
||||||
|
if fs_cache_cfg.enabled:
|
||||||
|
from app.utils.file_cache import FileCache
|
||||||
|
|
||||||
|
self._file_cache = FileCache(max_entries=fs_cache_cfg.max_entries)
|
||||||
|
|
||||||
|
# Create tool registry (SkillRunner wired after registry exists)
|
||||||
|
self._tool_registry = create_default_registry(
|
||||||
|
self._config.agent.workspace_root,
|
||||||
|
self._config,
|
||||||
|
skills_manager=self._skills_manager,
|
||||||
|
file_cache=self._file_cache,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create SkillRunner and late-bind it to skill tools
|
||||||
|
if self._skills_manager is not None and self._tool_registry is not None:
|
||||||
|
from app.services.skill_runner import SkillRunner
|
||||||
|
|
||||||
|
self._skill_runner = SkillRunner(
|
||||||
|
self._skills_manager,
|
||||||
|
self._config,
|
||||||
|
self._ctx,
|
||||||
|
self._tool_registry,
|
||||||
|
)
|
||||||
|
# Late-bind runner to skill tools already in the registry
|
||||||
|
load_tool = self._tool_registry.get("load_skill")
|
||||||
|
if load_tool and hasattr(load_tool, "set_skill_runner"):
|
||||||
|
load_tool.set_skill_runner(self._skill_runner)
|
||||||
|
finish_tool = self._tool_registry.get("finish_skill")
|
||||||
|
if finish_tool and hasattr(finish_tool, "set_skill_runner"):
|
||||||
|
finish_tool.set_skill_runner(self._skill_runner)
|
||||||
|
|
||||||
|
# Set up permission prompt callback
|
||||||
|
async def permission_prompt(tool_name: str, description: str) -> bool:
|
||||||
|
return await self._show_permission_modal(tool_name, description)
|
||||||
|
|
||||||
|
self._permissions.set_prompt_callback(permission_prompt)
|
||||||
|
|
||||||
|
# Offer session resume if configured (must run in a worker for push_screen_wait)
|
||||||
|
self._offer_session_resume()
|
||||||
|
|
||||||
|
@work
|
||||||
|
async def _offer_session_resume(self) -> None:
|
||||||
|
"""Offer to resume a previous session, running in a worker for modal support."""
|
||||||
|
if self._session_mgr and self._config.session.offer_resume:
|
||||||
|
saved = self._session_mgr.load_latest()
|
||||||
|
if saved:
|
||||||
|
log = self.query_one("#chat-log", RichLog)
|
||||||
|
msg_count = len(saved.messages)
|
||||||
|
resume = await self.push_screen_wait(SessionResumeModal(msg_count))
|
||||||
|
if resume:
|
||||||
|
self._session_mgr.restore(saved, self._ctx)
|
||||||
|
log.write(Text("Session restored", style="bold green"))
|
||||||
|
else:
|
||||||
|
log.write(Text("Starting fresh session", style="cyan"))
|
||||||
|
self.query_one(HistoryInput).focus()
|
||||||
|
|
||||||
|
async def on_input_submitted(self, event: Input.Submitted) -> None:
|
||||||
|
"""Handle user input submission."""
|
||||||
|
user_input = event.value.strip()
|
||||||
|
if not user_input:
|
||||||
|
return
|
||||||
|
|
||||||
|
event.input.clear()
|
||||||
|
event.input.record(user_input)
|
||||||
|
log = self.query_one("#chat-log", RichLog)
|
||||||
|
|
||||||
|
# Echo user prompt (condensed for multi-line)
|
||||||
|
from app.utils.display import render_user_message
|
||||||
|
log.write(render_user_message(user_input))
|
||||||
|
|
||||||
|
# Handle slash commands
|
||||||
|
if user_input.startswith("/"):
|
||||||
|
await self._handle_slash_command(user_input, log)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Dispatch agent turn as async worker
|
||||||
|
self._cancel_count = 0
|
||||||
|
self._current_worker = self.run_worker(
|
||||||
|
self._run_agent_turn(user_input),
|
||||||
|
name="agent-turn",
|
||||||
|
exclusive=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_slash_command(self, command: str, log: RichLog) -> None:
|
||||||
|
"""Process slash commands."""
|
||||||
|
cmd = command.lower()
|
||||||
|
if cmd == "/help":
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
table = Table(title="SneakyCode Commands", show_lines=False)
|
||||||
|
table.add_column("Command", style="cyan", no_wrap=True)
|
||||||
|
table.add_column("Description")
|
||||||
|
table.add_row("/help", "Show this help message")
|
||||||
|
table.add_row("/quit, /exit, /bye", "Save session and exit")
|
||||||
|
table.add_row("/clear", "Clear conversation history")
|
||||||
|
table.add_row("/history", "Show conversation history")
|
||||||
|
table.add_row("/save", "Manually save session")
|
||||||
|
table.add_row("/session", "Show session info (messages, tokens, start time)")
|
||||||
|
table.add_row("/models, /model", "List available Ollama models")
|
||||||
|
table.add_row("/model <name>", "Switch to a different model")
|
||||||
|
table.add_row("/mode", "Show current agent mode")
|
||||||
|
table.add_row("/mode normal|plan|auto", "Switch agent mode")
|
||||||
|
table.add_row("/skills", "List available skills")
|
||||||
|
table.add_row("/<skill>", "Load a skill by name")
|
||||||
|
log.write(table)
|
||||||
|
elif cmd in ("/quit", "/exit", "/bye"):
|
||||||
|
self._save_session()
|
||||||
|
self.exit()
|
||||||
|
elif cmd == "/clear":
|
||||||
|
if self._ctx:
|
||||||
|
self._ctx.clear_history()
|
||||||
|
log.clear()
|
||||||
|
log.write(Text("✓ Conversation history cleared.", style="bold green"))
|
||||||
|
elif cmd == "/history":
|
||||||
|
if self._ctx:
|
||||||
|
from app.utils.display import render_history
|
||||||
|
log.write(render_history(self._ctx.get_history()))
|
||||||
|
elif cmd == "/save":
|
||||||
|
path = self._save_session()
|
||||||
|
if path:
|
||||||
|
log.write(Text(f"✓ Session saved to {path}", style="bold green"))
|
||||||
|
else:
|
||||||
|
log.write(Text("✗ No session to save", style="bold red"))
|
||||||
|
elif cmd == "/session":
|
||||||
|
if self._ctx:
|
||||||
|
log.write(Text(
|
||||||
|
f"Messages: {self._ctx.message_count} | "
|
||||||
|
f"Tokens: ~{self._ctx.estimated_tokens:,} / {self._ctx.token_counter.budget:,} | "
|
||||||
|
f"Started: {self._ctx.start_time.isoformat()}",
|
||||||
|
style="cyan",
|
||||||
|
))
|
||||||
|
elif cmd.split()[0] in ("/models", "/model"):
|
||||||
|
parts = command.split(maxsplit=1)
|
||||||
|
if len(parts) == 1:
|
||||||
|
# List available models
|
||||||
|
try:
|
||||||
|
from app.services.llm import LLMError
|
||||||
|
|
||||||
|
models = await self._client.list_models()
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
table = Table(title="Available Models", show_lines=False)
|
||||||
|
table.add_column("Model", style="cyan")
|
||||||
|
table.add_column("Size", style="dim")
|
||||||
|
current = self._config.llm.model
|
||||||
|
for m in models:
|
||||||
|
name = m["name"]
|
||||||
|
marker = " (active)" if current in name or name.startswith(current) else ""
|
||||||
|
table.add_row(f"{name}{marker}", m["size"])
|
||||||
|
log.write(table)
|
||||||
|
except Exception as e:
|
||||||
|
log.write(Text(f"Failed to list models: {e}", style="red"))
|
||||||
|
else:
|
||||||
|
new_model = parts[1].strip()
|
||||||
|
self._config.llm.model = new_model
|
||||||
|
if self._session_mgr:
|
||||||
|
self._session_mgr.update_model(new_model)
|
||||||
|
# Apply model-specific profile overrides
|
||||||
|
profile = self._config.apply_model_profile(new_model)
|
||||||
|
if profile and self._ctx:
|
||||||
|
# Update token budget if the profile overrides it
|
||||||
|
self._ctx.token_counter.budget = self._config.agent.max_conversation_tokens
|
||||||
|
self.query_one(HeaderPanel).update_model(new_model)
|
||||||
|
header = self.query_one(HeaderPanel)
|
||||||
|
header.update_tokens(
|
||||||
|
self._ctx.estimated_tokens if self._ctx else 0,
|
||||||
|
self._config.agent.max_conversation_tokens,
|
||||||
|
)
|
||||||
|
msg = f"Switched to model: {new_model}"
|
||||||
|
if profile:
|
||||||
|
overrides = []
|
||||||
|
if profile.max_conversation_tokens is not None:
|
||||||
|
overrides.append(f"tokens={profile.max_conversation_tokens:,}")
|
||||||
|
if profile.thinking is not None:
|
||||||
|
overrides.append(f"thinking={'on' if profile.thinking else 'off'}")
|
||||||
|
if overrides:
|
||||||
|
msg += f" ({', '.join(overrides)})"
|
||||||
|
log.write(Text(msg, style="bold green"))
|
||||||
|
elif cmd.split()[0] == "/mode":
|
||||||
|
parts = command.split(maxsplit=1)
|
||||||
|
if len(parts) == 1:
|
||||||
|
current = self._permissions.mode
|
||||||
|
log.write(Text(f"Current mode: {current.value}", style="cyan"))
|
||||||
|
else:
|
||||||
|
mode_str = parts[1].strip().lower()
|
||||||
|
try:
|
||||||
|
new_mode = AgentMode(mode_str)
|
||||||
|
except ValueError:
|
||||||
|
log.write(Text(f"Unknown mode: {mode_str}. Use normal, plan, or auto.", style="yellow"))
|
||||||
|
return
|
||||||
|
self._permissions.mode = new_mode
|
||||||
|
self.query_one(HeaderPanel).update_mode(new_mode)
|
||||||
|
log.write(Text(f"Switched to {new_mode.value} mode", style="bold green"))
|
||||||
|
elif cmd == "/skills":
|
||||||
|
if self._skills_manager:
|
||||||
|
skills = self._skills_manager.list_skills()
|
||||||
|
if not skills:
|
||||||
|
log.write(Text("No skills found", style="yellow"))
|
||||||
|
else:
|
||||||
|
from rich.table import Table
|
||||||
|
|
||||||
|
table = Table(title="Available Skills")
|
||||||
|
table.add_column("Name", style="cyan")
|
||||||
|
table.add_column("Description")
|
||||||
|
for s in skills:
|
||||||
|
table.add_row(f"/{s.name}", s.description)
|
||||||
|
log.write(table)
|
||||||
|
else:
|
||||||
|
log.write(Text("Skills system is disabled", style="yellow"))
|
||||||
|
else:
|
||||||
|
# Try as skill trigger (package skill via SkillRunner)
|
||||||
|
if self._skill_runner and self._skills_manager:
|
||||||
|
skill = self._skills_manager.get_skill_by_trigger(cmd.lstrip("/"))
|
||||||
|
if skill is not None:
|
||||||
|
content = self._skill_runner.activate(skill.name)
|
||||||
|
status_bar = self.query_one(StatusBar)
|
||||||
|
status_bar.set_active_skill(skill.name)
|
||||||
|
log.write(Text(f"Skill activated: {skill.name}", style="bold green"))
|
||||||
|
# Run an agent turn so the LLM sees the skill context
|
||||||
|
self._cancel_count = 0
|
||||||
|
self._current_worker = self.run_worker(
|
||||||
|
self._run_agent_turn(f"[Skill activated: {skill.name}]"),
|
||||||
|
name="agent-turn",
|
||||||
|
exclusive=True,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Try as legacy skill invocation
|
||||||
|
skill_name = cmd.lstrip("/")
|
||||||
|
if self._skills_manager:
|
||||||
|
content = self._skills_manager.load_skill(skill_name)
|
||||||
|
if content is not None:
|
||||||
|
if self._ctx:
|
||||||
|
self._ctx.add_message("system", f"[Skill: {skill_name}]\n{content}")
|
||||||
|
log.write(Text(f"Loaded skill: {skill_name}", style="bold green"))
|
||||||
|
return
|
||||||
|
log.write(Text(f"Unknown command: {command}", style="yellow"))
|
||||||
|
|
||||||
|
async def _run_agent_turn(self, user_input: str) -> None:
|
||||||
|
"""Run a single agent turn (called as a worker)."""
|
||||||
|
if self._ctx is None or self._client is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
log = self.query_one("#chat-log", RichLog)
|
||||||
|
streaming_widget = self.query_one("#streaming", StreamingStatic)
|
||||||
|
status_bar = self.query_one(StatusBar)
|
||||||
|
display = DisplayAdapter(log)
|
||||||
|
|
||||||
|
# StreamHandler is per-turn (has per-turn accumulators)
|
||||||
|
handler = StreamHandler(self._config.display)
|
||||||
|
|
||||||
|
status_bar.start_streaming()
|
||||||
|
|
||||||
|
# Set up streaming UI callbacks
|
||||||
|
header = self.query_one(HeaderPanel)
|
||||||
|
|
||||||
|
def on_content(content: str) -> None:
|
||||||
|
streaming_widget.update(
|
||||||
|
Panel(Markdown(content), title="Assistant", border_style="green", expand=True)
|
||||||
|
)
|
||||||
|
streaming_widget.show_streaming()
|
||||||
|
stream_tokens = len(content) // 4
|
||||||
|
status_bar.update_stream_tokens(stream_tokens)
|
||||||
|
header.update_tokens(
|
||||||
|
self._ctx.estimated_tokens + stream_tokens,
|
||||||
|
self._ctx.token_counter.budget,
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_thinking() -> None:
|
||||||
|
streaming_widget.update(Text("Thinking...", style="dim"))
|
||||||
|
streaming_widget.show_streaming()
|
||||||
|
|
||||||
|
def on_done() -> None:
|
||||||
|
streaming_widget.hide_streaming()
|
||||||
|
status_bar.stop_streaming()
|
||||||
|
|
||||||
|
handler.set_callbacks(on_content=on_content, on_thinking=on_thinking, on_done=on_done)
|
||||||
|
|
||||||
|
agent = AgentLoop(
|
||||||
|
self._config, self._ctx, self._client, handler,
|
||||||
|
self._tool_registry, self._permissions, display,
|
||||||
|
debug_logger=self._debug_logger,
|
||||||
|
skills_manager=self._skills_manager,
|
||||||
|
skill_runner=self._skill_runner,
|
||||||
|
)
|
||||||
|
|
||||||
|
await agent.run_turn(user_input)
|
||||||
|
|
||||||
|
status_bar.stop_streaming()
|
||||||
|
|
||||||
|
# Update token display in header
|
||||||
|
header = self.query_one(HeaderPanel)
|
||||||
|
header.update_tokens(self._ctx.estimated_tokens, self._ctx.token_counter.budget)
|
||||||
|
|
||||||
|
# Update skill indicator (skill may have been deactivated via finish_skill)
|
||||||
|
if self._skill_runner and not self._skill_runner.is_active:
|
||||||
|
status_bar.set_active_skill(None)
|
||||||
|
elif self._skill_runner and self._skill_runner.is_active:
|
||||||
|
status_bar.set_active_skill(self._skill_runner.active_skill_name)
|
||||||
|
|
||||||
|
# Auto-save
|
||||||
|
if self._config.session.auto_save:
|
||||||
|
self._save_session()
|
||||||
|
|
||||||
|
async def _show_permission_modal(self, tool_name: str, description: str) -> bool:
|
||||||
|
"""Show a modal dialog for tool permission approval."""
|
||||||
|
return await self.push_screen_wait(PermissionModal(tool_name, description))
|
||||||
|
|
||||||
|
def action_cancel_or_quit(self) -> None:
|
||||||
|
"""Handle Ctrl+C: first press cancels worker, second press quits."""
|
||||||
|
self._cancel_count += 1
|
||||||
|
if self._cancel_count >= 2 or self._current_worker is None:
|
||||||
|
self._save_session()
|
||||||
|
self.exit()
|
||||||
|
elif self._current_worker is not None:
|
||||||
|
self._current_worker.cancel()
|
||||||
|
log = self.query_one("#chat-log", RichLog)
|
||||||
|
log.write(Text("⚠ Cancelling... (press Ctrl+C again to quit)", style="yellow"))
|
||||||
|
|
||||||
|
def action_cycle_mode(self) -> None:
|
||||||
|
"""Cycle through agent modes: Normal → Plan → Auto → Normal."""
|
||||||
|
if self._permissions is None:
|
||||||
|
return
|
||||||
|
cycle = {
|
||||||
|
AgentMode.NORMAL: AgentMode.PLAN,
|
||||||
|
AgentMode.PLAN: AgentMode.AUTO,
|
||||||
|
AgentMode.AUTO: AgentMode.NORMAL,
|
||||||
|
}
|
||||||
|
new_mode = cycle[self._permissions.mode]
|
||||||
|
self._permissions.mode = new_mode
|
||||||
|
self.query_one(HeaderPanel).update_mode(new_mode)
|
||||||
|
log = self.query_one("#chat-log", RichLog)
|
||||||
|
log.write(Text(f"Mode: {new_mode.value}", style="bold green"))
|
||||||
|
|
||||||
|
async def on_unmount(self) -> None:
|
||||||
|
"""Clean up the LLM client on app shutdown."""
|
||||||
|
if self._client is not None:
|
||||||
|
await self._client.close()
|
||||||
|
|
||||||
|
def on_worker_state_changed(self, event: Worker.StateChanged) -> None:
|
||||||
|
"""Handle worker completion or failure."""
|
||||||
|
from textual.worker import WorkerState
|
||||||
|
|
||||||
|
if event.worker.name != "agent-turn":
|
||||||
|
return
|
||||||
|
|
||||||
|
if event.state == WorkerState.ERROR:
|
||||||
|
log = self.query_one("#chat-log", RichLog)
|
||||||
|
error = event.worker.error
|
||||||
|
log.write(Text(f"✗ Agent error: {error}", style="bold red"))
|
||||||
|
|
||||||
|
if event.state in (WorkerState.SUCCESS, WorkerState.ERROR, WorkerState.CANCELLED):
|
||||||
|
self._current_worker = None
|
||||||
|
# Hide streaming widget and stop spinner in case they were left active
|
||||||
|
streaming = self.query_one("#streaming", StreamingStatic)
|
||||||
|
streaming.hide_streaming()
|
||||||
|
self.query_one(StatusBar).stop_streaming()
|
||||||
|
|
||||||
|
def _save_session(self) -> Path | None:
|
||||||
|
"""Save session quietly, return path or None."""
|
||||||
|
if self._session_mgr and self._ctx and self._ctx.message_count > 0:
|
||||||
|
try:
|
||||||
|
return self._session_mgr.save(self._ctx)
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning("session_save_failed", error=str(e))
|
||||||
|
return None
|
||||||
62
app/ui/styles.tcss
Normal file
62
app/ui/styles.tcss
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
/* SneakyCode TUI Layout */
|
||||||
|
|
||||||
|
Screen {
|
||||||
|
layout: vertical;
|
||||||
|
}
|
||||||
|
|
||||||
|
#chat-log {
|
||||||
|
height: 1fr;
|
||||||
|
border: none;
|
||||||
|
scrollbar-gutter: stable;
|
||||||
|
}
|
||||||
|
|
||||||
|
#streaming {
|
||||||
|
display: none;
|
||||||
|
height: auto;
|
||||||
|
max-height: 50%;
|
||||||
|
padding: 0 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#streaming.visible {
|
||||||
|
display: block;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Modal dialog */
|
||||||
|
#permission-dialog {
|
||||||
|
width: 60;
|
||||||
|
height: auto;
|
||||||
|
max-height: 12;
|
||||||
|
border: thick $accent;
|
||||||
|
background: $surface;
|
||||||
|
padding: 1 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modal-title {
|
||||||
|
text-style: bold;
|
||||||
|
margin-bottom: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modal-body {
|
||||||
|
margin-bottom: 1;
|
||||||
|
color: $text-muted;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modal-buttons {
|
||||||
|
height: 3;
|
||||||
|
align: center middle;
|
||||||
|
}
|
||||||
|
|
||||||
|
.modal-buttons Button {
|
||||||
|
margin: 0 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* StatusBar styles are in DEFAULT_CSS on the widget itself */
|
||||||
|
|
||||||
|
Input {
|
||||||
|
dock: bottom;
|
||||||
|
margin: 0;
|
||||||
|
border: heavy darkcyan;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* HeaderPanel styles are in DEFAULT_CSS on the widget itself */
|
||||||
295
app/ui/widgets.py
Normal file
295
app/ui/widgets.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""Custom Textual widgets for SneakyCode TUI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from textual.app import ComposeResult
|
||||||
|
from textual.containers import Horizontal, Vertical
|
||||||
|
from textual.events import Key
|
||||||
|
from textual.screen import ModalScreen
|
||||||
|
from textual.timer import Timer
|
||||||
|
from textual.widgets import Button, Input, Static
|
||||||
|
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from app.models.config import AgentMode
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Header Panel
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class HeaderPanel(Static):
|
||||||
|
"""Single-line header showing model name, agent mode, and token usage."""
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
HeaderPanel {
|
||||||
|
dock: top;
|
||||||
|
height: 1;
|
||||||
|
background: darkcyan;
|
||||||
|
color: $text;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str) -> None:
|
||||||
|
super().__init__("")
|
||||||
|
self._model_name = model_name
|
||||||
|
self._mode: AgentMode = AgentMode.NORMAL
|
||||||
|
self._tokens: int = 0
|
||||||
|
self._budget: int = 0
|
||||||
|
|
||||||
|
def on_resize(self) -> None:
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def update_model(self, name: str) -> None:
|
||||||
|
"""Update the displayed model name."""
|
||||||
|
self._model_name = name
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def update_mode(self, mode: AgentMode) -> None:
|
||||||
|
"""Update the displayed agent mode."""
|
||||||
|
self._mode = mode
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def update_tokens(self, tokens: int, budget: int) -> None:
|
||||||
|
"""Update the token usage display."""
|
||||||
|
self._tokens = tokens
|
||||||
|
self._budget = budget
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def _refresh_display(self) -> None:
|
||||||
|
"""Rebuild the header text."""
|
||||||
|
left = Text.assemble(
|
||||||
|
("⚡ SneakyCode", "bold"),
|
||||||
|
" │ ",
|
||||||
|
(self._model_name, "bold"),
|
||||||
|
)
|
||||||
|
|
||||||
|
mode_styles = {
|
||||||
|
AgentMode.NORMAL: ("NORMAL", "bold black on white"),
|
||||||
|
AgentMode.PLAN: ("PLAN", "bold black on yellow"),
|
||||||
|
AgentMode.AUTO: ("AUTO", "bold white on red"),
|
||||||
|
}
|
||||||
|
mode_label, mode_style = mode_styles[self._mode]
|
||||||
|
mode_text = Text.assemble((" ", mode_style), (mode_label, mode_style), (" ", mode_style))
|
||||||
|
|
||||||
|
right = Text(f"~{self._tokens:,} / {self._budget:,} tokens")
|
||||||
|
|
||||||
|
# Pad between sections
|
||||||
|
total_content = left.plain + " " + mode_text.plain + " " + right.plain
|
||||||
|
available = self.size.width if self.size.width > 0 else 80
|
||||||
|
gap_left = max(1, (available - len(total_content)) // 2)
|
||||||
|
gap_right = max(1, available - len(total_content) - gap_left)
|
||||||
|
|
||||||
|
full = Text.assemble(
|
||||||
|
left, " " * gap_left, mode_text, " " * gap_right, right,
|
||||||
|
)
|
||||||
|
self.update(full)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Modal Dialogs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionModal(ModalScreen[bool]):
|
||||||
|
"""Modal dialog for tool permission approval."""
|
||||||
|
|
||||||
|
BINDINGS = [("y", "allow", "Allow"), ("n", "deny", "Deny")]
|
||||||
|
|
||||||
|
def __init__(self, tool_name: str, description: str) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._tool_name = tool_name
|
||||||
|
self._description = description
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
with Vertical(id="permission-dialog"):
|
||||||
|
yield Static(f"Tool: {self._tool_name}", classes="modal-title")
|
||||||
|
yield Static(self._description, classes="modal-body")
|
||||||
|
with Horizontal(classes="modal-buttons"):
|
||||||
|
yield Button("Allow (y)", id="allow", variant="success")
|
||||||
|
yield Button("Deny (n)", id="deny", variant="error")
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
self.dismiss(event.button.id == "allow")
|
||||||
|
|
||||||
|
def action_allow(self) -> None:
|
||||||
|
self.dismiss(True)
|
||||||
|
|
||||||
|
def action_deny(self) -> None:
|
||||||
|
self.dismiss(False)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionResumeModal(ModalScreen[bool]):
|
||||||
|
"""Modal dialog for session resume prompt."""
|
||||||
|
|
||||||
|
BINDINGS = [("y", "resume", "Resume"), ("n", "fresh", "Start Fresh")]
|
||||||
|
|
||||||
|
def __init__(self, msg_count: int) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._msg_count = msg_count
|
||||||
|
|
||||||
|
def compose(self) -> ComposeResult:
|
||||||
|
with Vertical(id="permission-dialog"):
|
||||||
|
yield Static("Resume Session?", classes="modal-title")
|
||||||
|
yield Static(
|
||||||
|
f"Found previous session with {self._msg_count} messages.",
|
||||||
|
classes="modal-body",
|
||||||
|
)
|
||||||
|
with Horizontal(classes="modal-buttons"):
|
||||||
|
yield Button("Resume (y)", id="resume", variant="success")
|
||||||
|
yield Button("Start Fresh (n)", id="fresh", variant="warning")
|
||||||
|
|
||||||
|
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||||
|
self.dismiss(event.button.id == "resume")
|
||||||
|
|
||||||
|
def action_resume(self) -> None:
|
||||||
|
self.dismiss(True)
|
||||||
|
|
||||||
|
def action_fresh(self) -> None:
|
||||||
|
self.dismiss(False)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Input with History
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class HistoryInput(Input):
|
||||||
|
"""Input with up/down arrow history cycling."""
|
||||||
|
|
||||||
|
def __init__(self, **kwargs: object) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._history: list[str] = []
|
||||||
|
self._history_index: int = -1
|
||||||
|
self._draft: str = ""
|
||||||
|
|
||||||
|
def record(self, value: str) -> None:
|
||||||
|
"""Record a submitted value in history."""
|
||||||
|
if value.strip() and (not self._history or self._history[-1] != value):
|
||||||
|
self._history.append(value)
|
||||||
|
self._history_index = -1
|
||||||
|
self._draft = ""
|
||||||
|
|
||||||
|
def on_key(self, event: Key) -> None:
|
||||||
|
if event.key == "up" and self._history:
|
||||||
|
event.prevent_default()
|
||||||
|
if self._history_index == -1:
|
||||||
|
self._draft = self.value
|
||||||
|
self._history_index = len(self._history) - 1
|
||||||
|
elif self._history_index > 0:
|
||||||
|
self._history_index -= 1
|
||||||
|
self.value = self._history[self._history_index]
|
||||||
|
self.cursor_position = len(self.value)
|
||||||
|
elif event.key == "down" and self._history_index != -1:
|
||||||
|
event.prevent_default()
|
||||||
|
self._history_index += 1
|
||||||
|
if self._history_index >= len(self._history):
|
||||||
|
self._history_index = -1
|
||||||
|
self.value = self._draft
|
||||||
|
else:
|
||||||
|
self.value = self._history[self._history_index]
|
||||||
|
self.cursor_position = len(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Status Bar
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class StatusBar(Static):
|
||||||
|
"""Single-line status bar showing token usage, iteration count, and streaming spinner."""
|
||||||
|
|
||||||
|
_SPINNER = "\u280b\u2819\u2839\u2838\u283c\u2834\u2826\u2827\u2807\u280f"
|
||||||
|
|
||||||
|
DEFAULT_CSS = """
|
||||||
|
StatusBar {
|
||||||
|
dock: bottom;
|
||||||
|
height: 1;
|
||||||
|
background: $surface;
|
||||||
|
color: $text-muted;
|
||||||
|
padding: 0 2;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__("")
|
||||||
|
self._iteration: int = 0
|
||||||
|
self._max_iterations: int = 0
|
||||||
|
self._streaming: bool = False
|
||||||
|
self._spinner_frame: int = 0
|
||||||
|
self._spinner_timer: Timer | None = None
|
||||||
|
self._stream_tokens: int = 0
|
||||||
|
self._active_skill: str | None = None
|
||||||
|
|
||||||
|
def update_iteration(self, iteration: int, max_iterations: int) -> None:
|
||||||
|
"""Update the iteration count display."""
|
||||||
|
self._iteration = iteration
|
||||||
|
self._max_iterations = max_iterations
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def start_streaming(self) -> None:
|
||||||
|
"""Start the animated thinking spinner."""
|
||||||
|
self._streaming = True
|
||||||
|
self._spinner_frame = 0
|
||||||
|
self._stream_tokens = 0
|
||||||
|
self._spinner_timer = self.set_interval(0.08, self._tick_spinner)
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def stop_streaming(self) -> None:
|
||||||
|
"""Stop the animated thinking spinner."""
|
||||||
|
self._streaming = False
|
||||||
|
if self._spinner_timer:
|
||||||
|
self._spinner_timer.stop()
|
||||||
|
self._spinner_timer = None
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def update_stream_tokens(self, tokens: int) -> None:
|
||||||
|
"""Update estimated token count during streaming."""
|
||||||
|
self._stream_tokens = tokens
|
||||||
|
|
||||||
|
def _tick_spinner(self) -> None:
|
||||||
|
self._spinner_frame = (self._spinner_frame + 1) % len(self._SPINNER)
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def set_active_skill(self, skill_name: str | None) -> None:
|
||||||
|
"""Set or clear the active skill indicator."""
|
||||||
|
self._active_skill = skill_name
|
||||||
|
self._refresh_display()
|
||||||
|
|
||||||
|
def _refresh_display(self) -> None:
|
||||||
|
"""Rebuild the status bar text."""
|
||||||
|
parts: list[str] = []
|
||||||
|
if self._active_skill:
|
||||||
|
parts.append(f"[Skill: {self._active_skill}]")
|
||||||
|
if self._streaming:
|
||||||
|
spinner = self._SPINNER[self._spinner_frame]
|
||||||
|
parts.append(f"{spinner} Thinking")
|
||||||
|
if self._stream_tokens > 0:
|
||||||
|
parts.append(f"~{self._stream_tokens:,} tokens")
|
||||||
|
if self._max_iterations > 0:
|
||||||
|
parts.append(f"Iteration {self._iteration}/{self._max_iterations}")
|
||||||
|
self.update(Text(" \u2502 ".join(parts), style="dim"))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Streaming Display
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingStatic(Static):
|
||||||
|
"""A Static widget that stays mounted but hidden during non-streaming periods.
|
||||||
|
|
||||||
|
During streaming, call show() to make visible and update() with partial content.
|
||||||
|
When streaming ends, call hide() to conceal.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def show_streaming(self) -> None:
|
||||||
|
"""Make the widget visible for streaming."""
|
||||||
|
self.add_class("visible")
|
||||||
|
|
||||||
|
def hide_streaming(self) -> None:
|
||||||
|
"""Hide the widget and clear content."""
|
||||||
|
self.remove_class("visible")
|
||||||
|
self.update("")
|
||||||
@@ -1,7 +1,19 @@
|
|||||||
"""Rich terminal display helpers for SneakyCode."""
|
"""Rich terminal display helpers for SneakyCode.
|
||||||
|
|
||||||
|
Provides two interfaces:
|
||||||
|
- render_* functions: return Rich renderables (for use in Textual RichLog)
|
||||||
|
- DisplayAdapter: wraps a RichLog widget and provides write_* convenience methods
|
||||||
|
- print_* functions: legacy console.print wrappers (for non-TUI fallback)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
from rich.markdown import Markdown
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
from rich.theme import Theme
|
from rich.theme import Theme
|
||||||
|
|
||||||
from app.models.message import Message
|
from app.models.message import Message
|
||||||
@@ -23,6 +35,179 @@ SNEAKYCODE_THEME = Theme(
|
|||||||
console.push_theme(SNEAKYCODE_THEME)
|
console.push_theme(SNEAKYCODE_THEME)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from rich.console import RenderableType
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Render functions — return Rich renderables
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def render_user_message(content: str) -> Text:
|
||||||
|
"""Render a condensed user prompt as a single styled line.
|
||||||
|
|
||||||
|
Multi-line input is collapsed to the first line with a line count suffix.
|
||||||
|
Long single lines are truncated.
|
||||||
|
"""
|
||||||
|
lines = content.splitlines()
|
||||||
|
first = lines[0] if lines else content
|
||||||
|
max_len = 120
|
||||||
|
if len(first) > max_len:
|
||||||
|
first = first[:max_len] + "…"
|
||||||
|
suffix = f" (+{len(lines) - 1} lines)" if len(lines) > 1 else ""
|
||||||
|
text = Text()
|
||||||
|
text.append("You: ", style="bold cyan")
|
||||||
|
text.append(first + suffix, style="cyan")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def render_assistant_message(content: str) -> Panel:
|
||||||
|
"""Render an assistant message as a styled panel."""
|
||||||
|
return Panel(Markdown(content), title="Assistant", border_style="green", expand=True)
|
||||||
|
|
||||||
|
|
||||||
|
def render_tool_call(name: str, args: str) -> Text:
|
||||||
|
"""Render a compact tool call line."""
|
||||||
|
truncated_args = args[:80] + "..." if len(args) > 80 else args
|
||||||
|
text = Text()
|
||||||
|
text.append(" ")
|
||||||
|
text.append(name, style="magenta")
|
||||||
|
text.append(f" {truncated_args}", style="dim")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def render_tool_result(name: str, output: str, is_error: bool = False) -> Text:
|
||||||
|
"""Render a compact tool result line."""
|
||||||
|
text = Text()
|
||||||
|
text.append(" ")
|
||||||
|
if is_error:
|
||||||
|
truncated = output[:200] + "..." if len(output) > 200 else output
|
||||||
|
text.append(f"{name}: {truncated}", style="bold red")
|
||||||
|
else:
|
||||||
|
lines = output.count("\n") + 1 if output else 0
|
||||||
|
chars = len(output)
|
||||||
|
text.append(f"{name} — {lines} lines, {chars} chars", style="dim")
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def render_iteration_header(iteration: int, max_iterations: int) -> Text:
|
||||||
|
"""Render the current agent loop iteration."""
|
||||||
|
return Text(f"── iteration {iteration}/{max_iterations} ──", style="dim")
|
||||||
|
|
||||||
|
|
||||||
|
def render_token_usage(usage_tokens: int, budget: int) -> Text:
|
||||||
|
"""Render token usage as styled text."""
|
||||||
|
return Text(f"Tokens: ~{usage_tokens:,} / {budget:,}", style="dim")
|
||||||
|
|
||||||
|
|
||||||
|
def render_warning(message: str) -> Text:
|
||||||
|
"""Render a warning message."""
|
||||||
|
return Text(f"⚠ {message}", style="yellow")
|
||||||
|
|
||||||
|
|
||||||
|
def render_error(message: str) -> Text:
|
||||||
|
"""Render an error message."""
|
||||||
|
return Text(f"✗ {message}", style="bold red")
|
||||||
|
|
||||||
|
|
||||||
|
def render_success(message: str) -> Text:
|
||||||
|
"""Render a success message."""
|
||||||
|
return Text(f"✓ {message}", style="bold green")
|
||||||
|
|
||||||
|
|
||||||
|
def render_info(message: str) -> Text:
|
||||||
|
"""Render an info message."""
|
||||||
|
return Text(message, style="cyan")
|
||||||
|
|
||||||
|
|
||||||
|
def render_history(messages: list[Message]) -> Table | Text:
|
||||||
|
"""Render conversation history as a Rich table."""
|
||||||
|
if not messages:
|
||||||
|
return Text("No messages in history.", style="dim")
|
||||||
|
|
||||||
|
table = Table(title="Conversation History")
|
||||||
|
table.add_column("#", style="dim", width=4)
|
||||||
|
table.add_column("Role", width=10)
|
||||||
|
table.add_column("Content")
|
||||||
|
|
||||||
|
role_styles = {
|
||||||
|
"user": "cyan",
|
||||||
|
"assistant": "green",
|
||||||
|
"system": "yellow",
|
||||||
|
"tool": "magenta",
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, msg in enumerate(messages, 1):
|
||||||
|
style = role_styles.get(msg.role, "white")
|
||||||
|
content = msg.content or "(no content)"
|
||||||
|
if len(content) > 120:
|
||||||
|
content = content[:117] + "..."
|
||||||
|
table.add_row(str(i), f"[{style}]{msg.role}[/{style}]", content)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DisplayAdapter — wraps a writable target (RichLog or similar)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class WritableLog(Protocol):
|
||||||
|
"""Protocol for anything that accepts Rich renderables via write()."""
|
||||||
|
|
||||||
|
def write(self, content: "RenderableType") -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DisplayAdapter:
|
||||||
|
"""Bridges agent loop display calls to a RichLog widget.
|
||||||
|
|
||||||
|
All write_* methods call the corresponding render_* function and
|
||||||
|
write the result to the target log.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, log: WritableLog) -> None:
|
||||||
|
self._log = log
|
||||||
|
|
||||||
|
def write_user_message(self, content: str) -> None:
|
||||||
|
self._log.write(render_user_message(content))
|
||||||
|
|
||||||
|
def write_assistant_message(self, content: str) -> None:
|
||||||
|
self._log.write(render_assistant_message(content))
|
||||||
|
|
||||||
|
def write_tool_call(self, name: str, args: str) -> None:
|
||||||
|
self._log.write(render_tool_call(name, args))
|
||||||
|
|
||||||
|
def write_tool_result(self, name: str, output: str, is_error: bool = False) -> None:
|
||||||
|
self._log.write(render_tool_result(name, output, is_error))
|
||||||
|
|
||||||
|
def write_iteration_header(self, iteration: int, max_iterations: int) -> None:
|
||||||
|
self._log.write(render_iteration_header(iteration, max_iterations))
|
||||||
|
|
||||||
|
def write_token_usage(self, usage_tokens: int, budget: int) -> None:
|
||||||
|
self._log.write(render_token_usage(usage_tokens, budget))
|
||||||
|
|
||||||
|
def write_warning(self, message: str) -> None:
|
||||||
|
self._log.write(render_warning(message))
|
||||||
|
|
||||||
|
def write_error(self, message: str) -> None:
|
||||||
|
self._log.write(render_error(message))
|
||||||
|
|
||||||
|
def write_success(self, message: str) -> None:
|
||||||
|
self._log.write(render_success(message))
|
||||||
|
|
||||||
|
def write_info(self, message: str) -> None:
|
||||||
|
self._log.write(render_info(message))
|
||||||
|
|
||||||
|
def write_history(self, messages: list[Message]) -> None:
|
||||||
|
self._log.write(render_history(messages))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Legacy print functions — for non-TUI fallback and pre-TUI startup
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def print_banner() -> None:
|
def print_banner() -> None:
|
||||||
"""Print the SneakyCode startup banner."""
|
"""Print the SneakyCode startup banner."""
|
||||||
console.print(
|
console.print(
|
||||||
@@ -51,8 +236,8 @@ def print_success(message: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def print_user_message(content: str) -> None:
|
def print_user_message(content: str) -> None:
|
||||||
"""Print a user message in a styled panel."""
|
"""Print a condensed user prompt line."""
|
||||||
console.print(Panel(content, title="You", border_style="cyan", expand=False))
|
console.print(render_user_message(content))
|
||||||
|
|
||||||
|
|
||||||
def print_assistant_message(content: str) -> None:
|
def print_assistant_message(content: str) -> None:
|
||||||
@@ -61,25 +246,17 @@ def print_assistant_message(content: str) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def print_tool_call(name: str, args: str) -> None:
|
def print_tool_call(name: str, args: str) -> None:
|
||||||
"""Print a compact tool call line — tool name + truncated key args."""
|
"""Print a compact tool call line."""
|
||||||
truncated_args = args[:80] + "..." if len(args) > 80 else args
|
truncated_args = args[:80] + "..." if len(args) > 80 else args
|
||||||
console.print(f" [tool]{name}[/tool] [dim]{truncated_args}[/dim]")
|
console.print(f" [tool]{name}[/tool] [dim]{truncated_args}[/dim]")
|
||||||
|
|
||||||
|
|
||||||
def print_tool_result(name: str, output: str, is_error: bool = False) -> None:
|
def print_tool_result(name: str, output: str, is_error: bool = False) -> None:
|
||||||
"""Print a compact tool result — status line only for success, detail for errors.
|
"""Print a compact tool result."""
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Tool name.
|
|
||||||
output: Tool output or error message.
|
|
||||||
is_error: Whether this is an error result.
|
|
||||||
"""
|
|
||||||
if is_error:
|
if is_error:
|
||||||
# Errors are shown prominently so the user knows something went wrong
|
|
||||||
truncated = output[:200] + "..." if len(output) > 200 else output
|
truncated = output[:200] + "..." if len(output) > 200 else output
|
||||||
console.print(f" [error]{name}: {truncated}[/error]")
|
console.print(f" [error]{name}: {truncated}[/error]")
|
||||||
else:
|
else:
|
||||||
# Success: just show a compact byte/line summary
|
|
||||||
lines = output.count("\n") + 1 if output else 0
|
lines = output.count("\n") + 1 if output else 0
|
||||||
chars = len(output)
|
chars = len(output)
|
||||||
console.print(f" [dim]{name} — {lines} lines, {chars} chars[/dim]")
|
console.print(f" [dim]{name} — {lines} lines, {chars} chars[/dim]")
|
||||||
@@ -96,33 +273,5 @@ def print_token_usage(usage_tokens: int, budget: int) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def print_history(messages: list[Message]) -> None:
|
def print_history(messages: list[Message]) -> None:
|
||||||
"""Print conversation history as a Rich table.
|
"""Print conversation history as a Rich table."""
|
||||||
|
console.print(render_history(messages))
|
||||||
Args:
|
|
||||||
messages: List of conversation messages to display.
|
|
||||||
"""
|
|
||||||
if not messages:
|
|
||||||
console.print("[dim]No messages in history.[/dim]")
|
|
||||||
return
|
|
||||||
|
|
||||||
table = Table(title="Conversation History")
|
|
||||||
table.add_column("#", style="dim", width=4)
|
|
||||||
table.add_column("Role", width=10)
|
|
||||||
table.add_column("Content")
|
|
||||||
|
|
||||||
role_styles = {
|
|
||||||
"user": "cyan",
|
|
||||||
"assistant": "green",
|
|
||||||
"system": "yellow",
|
|
||||||
"tool": "magenta",
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, msg in enumerate(messages, 1):
|
|
||||||
style = role_styles.get(msg.role, "white")
|
|
||||||
content = msg.content or "[dim](no content)[/dim]"
|
|
||||||
# Truncate long content for display
|
|
||||||
if len(content) > 120:
|
|
||||||
content = content[:117] + "..."
|
|
||||||
table.add_row(str(i), f"[{style}]{msg.role}[/{style}]", content)
|
|
||||||
|
|
||||||
console.print(table)
|
|
||||||
|
|||||||
185
app/utils/file_cache.py
Normal file
185
app/utils/file_cache.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""File cache with LRU eviction and mtime-based invalidation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.utils.file_helpers import (
|
||||||
|
BinaryFileError,
|
||||||
|
FileSizeError,
|
||||||
|
PathSecurityError,
|
||||||
|
check_file_size,
|
||||||
|
is_binary_file,
|
||||||
|
resolve_safe_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CacheEntry:
|
||||||
|
"""A cached file's content and modification timestamp."""
|
||||||
|
|
||||||
|
content: str
|
||||||
|
mtime_ns: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheStats:
|
||||||
|
"""Running statistics for a FileCache instance."""
|
||||||
|
|
||||||
|
hits: int = 0
|
||||||
|
misses: int = 0
|
||||||
|
invalidations: int = 0
|
||||||
|
evictions: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hit_rate(self) -> float:
|
||||||
|
"""Return cache hit rate as a float between 0.0 and 1.0."""
|
||||||
|
total = self.hits + self.misses
|
||||||
|
if total == 0:
|
||||||
|
return 0.0
|
||||||
|
return self.hits / total
|
||||||
|
|
||||||
|
|
||||||
|
class FileCache:
|
||||||
|
"""LRU file-content cache with mtime-based invalidation.
|
||||||
|
|
||||||
|
Keyed by resolved absolute ``Path``. Each lookup performs a cheap
|
||||||
|
``stat()`` syscall to verify the file hasn't changed on disk — if the
|
||||||
|
nanosecond mtime differs the entry is evicted and the caller gets a
|
||||||
|
cache miss.
|
||||||
|
|
||||||
|
Not thread-safe (single-threaded agent loop).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_entries: int = 128) -> None:
|
||||||
|
self._max_entries = max_entries
|
||||||
|
self._entries: OrderedDict[Path, CacheEntry] = OrderedDict()
|
||||||
|
self._stats = CacheStats()
|
||||||
|
|
||||||
|
# -- public API --------------------------------------------------
|
||||||
|
|
||||||
|
def get(self, path: Path) -> str | None:
|
||||||
|
"""Return cached content if *path* hasn't changed, else ``None``.
|
||||||
|
|
||||||
|
A ``stat()`` call checks ``st_mtime_ns``; on mismatch the stale
|
||||||
|
entry is silently removed.
|
||||||
|
"""
|
||||||
|
entry = self._entries.get(path)
|
||||||
|
if entry is None:
|
||||||
|
self._stats.misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_mtime_ns = path.stat().st_mtime_ns
|
||||||
|
except OSError:
|
||||||
|
# File gone — evict and miss.
|
||||||
|
self._remove(path)
|
||||||
|
self._stats.misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
if current_mtime_ns != entry.mtime_ns:
|
||||||
|
self._remove(path)
|
||||||
|
self._stats.invalidations += 1
|
||||||
|
self._stats.misses += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Cache hit — move to end (most-recently used).
|
||||||
|
self._entries.move_to_end(path)
|
||||||
|
self._stats.hits += 1
|
||||||
|
return entry.content
|
||||||
|
|
||||||
|
def put(self, path: Path, content: str) -> None:
|
||||||
|
"""Store *content* for *path* with its current ``st_mtime_ns``.
|
||||||
|
|
||||||
|
Evicts the least-recently-used entry when over capacity.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
mtime_ns = path.stat().st_mtime_ns
|
||||||
|
except OSError:
|
||||||
|
# Can't stat — don't cache.
|
||||||
|
return
|
||||||
|
|
||||||
|
if path in self._entries:
|
||||||
|
# Update existing; move to end.
|
||||||
|
self._entries[path] = CacheEntry(content=content, mtime_ns=mtime_ns)
|
||||||
|
self._entries.move_to_end(path)
|
||||||
|
else:
|
||||||
|
self._entries[path] = CacheEntry(content=content, mtime_ns=mtime_ns)
|
||||||
|
|
||||||
|
# Evict LRU if over capacity.
|
||||||
|
while len(self._entries) > self._max_entries:
|
||||||
|
self._entries.popitem(last=False)
|
||||||
|
self._stats.evictions += 1
|
||||||
|
|
||||||
|
def invalidate(self, path: Path) -> None:
|
||||||
|
"""Remove *path* from the cache if present."""
|
||||||
|
if path in self._entries:
|
||||||
|
del self._entries[path]
|
||||||
|
self._stats.invalidations += 1
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Remove all entries."""
|
||||||
|
self._entries.clear()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def stats(self) -> CacheStats:
|
||||||
|
"""Return the running cache statistics."""
|
||||||
|
return self._stats
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self._entries)
|
||||||
|
|
||||||
|
# -- internals ---------------------------------------------------
|
||||||
|
|
||||||
|
def _remove(self, path: Path) -> None:
|
||||||
|
"""Delete an entry without bumping invalidation stats."""
|
||||||
|
self._entries.pop(path, None)
|
||||||
|
|
||||||
|
|
||||||
|
def cached_read_file(
|
||||||
|
path: str | Path,
|
||||||
|
workspace_root: Path,
|
||||||
|
max_size_bytes: int = 1_048_576,
|
||||||
|
check_binary: bool = True,
|
||||||
|
cache: FileCache | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Read a file with full security checks, using *cache* when available.
|
||||||
|
|
||||||
|
Security checks (path sandboxing, size limit, binary detection) run on
|
||||||
|
**every** call — only the ``Path.read_text()`` I/O is skipped on a cache
|
||||||
|
hit.
|
||||||
|
|
||||||
|
When *cache* is ``None`` this behaves identically to
|
||||||
|
:func:`~app.utils.file_helpers.safe_read_file`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PathSecurityError: If the path escapes the workspace.
|
||||||
|
FileSizeError: If the file is too large.
|
||||||
|
BinaryFileError: If the file is binary and *check_binary* is True.
|
||||||
|
FileNotFoundError: If the file does not exist.
|
||||||
|
"""
|
||||||
|
safe_path = resolve_safe_path(path, workspace_root)
|
||||||
|
|
||||||
|
if not safe_path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {safe_path}")
|
||||||
|
|
||||||
|
check_file_size(safe_path, max_size_bytes)
|
||||||
|
|
||||||
|
if check_binary and is_binary_file(safe_path):
|
||||||
|
raise BinaryFileError(f"File appears to be binary: {safe_path}")
|
||||||
|
|
||||||
|
# Try cache.
|
||||||
|
if cache is not None:
|
||||||
|
cached = cache.get(safe_path)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
|
|
||||||
|
# Cache miss (or no cache) — read from disk.
|
||||||
|
content = safe_path.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
cache.put(safe_path, content)
|
||||||
|
|
||||||
|
return content
|
||||||
@@ -77,6 +77,21 @@ def setup_logging(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging_for_tui() -> None:
|
||||||
|
"""Reconfigure logging for Textual TUI mode.
|
||||||
|
|
||||||
|
Removes the RichHandler (which writes to stdout and corrupts the TUI)
|
||||||
|
while preserving any file handlers. Call this from the Textual App's
|
||||||
|
on_mount() before any agent work begins.
|
||||||
|
"""
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
|
root_logger = logging.getLogger()
|
||||||
|
root_logger.handlers = [
|
||||||
|
h for h in root_logger.handlers if not isinstance(h, RichHandler)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
|
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
|
||||||
"""Get a named structlog logger.
|
"""Get a named structlog logger.
|
||||||
|
|
||||||
|
|||||||
@@ -36,6 +36,11 @@ class TokenCounter:
|
|||||||
"""The configured token budget."""
|
"""The configured token budget."""
|
||||||
return self._budget
|
return self._budget
|
||||||
|
|
||||||
|
@budget.setter
|
||||||
|
def budget(self, value: int) -> None:
|
||||||
|
"""Update the token budget (e.g., when switching models)."""
|
||||||
|
self._budget = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cumulative_usage(self) -> TokenUsage:
|
def cumulative_usage(self) -> TokenUsage:
|
||||||
"""Cumulative token usage across all tracked calls."""
|
"""Cumulative token usage across all tracked calls."""
|
||||||
|
|||||||
@@ -7,11 +7,34 @@ llm:
|
|||||||
temperature: 0.1
|
temperature: 0.1
|
||||||
max_tokens: 4096
|
max_tokens: 4096
|
||||||
timeout: 120
|
timeout: 120
|
||||||
|
max_retries: 3
|
||||||
|
retry_backoff_base: 1.0
|
||||||
|
retry_backoff_max: 30.0
|
||||||
|
thinking: false # Disable model thinking/reasoning mode (reduces reasoning-only loops)
|
||||||
|
# Extra parameters merged into the API request body (model-specific).
|
||||||
|
# Examples:
|
||||||
|
# OpenAI: reasoning_effort: "low"
|
||||||
|
extra_body: {}
|
||||||
|
|
||||||
agent:
|
agent:
|
||||||
max_iterations: 25
|
max_iterations: 25
|
||||||
max_conversation_tokens: 32000
|
max_conversation_tokens: 32000 # Default token budget (overridden by model_profiles)
|
||||||
workspace_root: "."
|
workspace_root: "."
|
||||||
|
truncation_keep_recent: 10
|
||||||
|
truncation_threshold: 0.85
|
||||||
|
|
||||||
|
# Per-model overrides — matched by longest model name prefix.
|
||||||
|
# Unset fields fall through to the defaults above.
|
||||||
|
model_profiles:
|
||||||
|
llama3:
|
||||||
|
max_conversation_tokens: 120000
|
||||||
|
thinking: false
|
||||||
|
qwen:
|
||||||
|
max_conversation_tokens: 32000
|
||||||
|
thinking: false
|
||||||
|
qwq:
|
||||||
|
max_conversation_tokens: 32000
|
||||||
|
thinking: true
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
auto_approve:
|
auto_approve:
|
||||||
@@ -38,7 +61,6 @@ tools:
|
|||||||
- pytest
|
- pytest
|
||||||
- ruff
|
- ruff
|
||||||
- ls
|
- ls
|
||||||
- cat
|
|
||||||
- head
|
- head
|
||||||
- tail
|
- tail
|
||||||
- wc
|
- wc
|
||||||
@@ -46,6 +68,10 @@ tools:
|
|||||||
- grep
|
- grep
|
||||||
- find
|
- find
|
||||||
- echo
|
- echo
|
||||||
|
- which
|
||||||
|
- jq
|
||||||
|
- type
|
||||||
|
- file
|
||||||
denied_commands:
|
denied_commands:
|
||||||
- rm -rf /
|
- rm -rf /
|
||||||
- sudo
|
- sudo
|
||||||
@@ -55,8 +81,27 @@ tools:
|
|||||||
filesystem:
|
filesystem:
|
||||||
max_file_size_bytes: 1048576 # 1 MB
|
max_file_size_bytes: 1048576 # 1 MB
|
||||||
binary_detection: true
|
binary_detection: true
|
||||||
|
cache:
|
||||||
|
enabled: true
|
||||||
|
max_entries: 128
|
||||||
|
|
||||||
|
session:
|
||||||
|
session_dir: ".sneakycode/sessions"
|
||||||
|
auto_save: true
|
||||||
|
max_session_age_hours: 72
|
||||||
|
offer_resume: true
|
||||||
|
|
||||||
display:
|
display:
|
||||||
show_tool_calls: true
|
show_tool_calls: true
|
||||||
show_token_usage: true
|
show_token_usage: true
|
||||||
stream_output: true
|
stream_output: true
|
||||||
|
|
||||||
|
debug:
|
||||||
|
enabled: false
|
||||||
|
log_dir: ".sneakycode/logs"
|
||||||
|
max_files: 10
|
||||||
|
|
||||||
|
skills:
|
||||||
|
enabled: true
|
||||||
|
directories:
|
||||||
|
- ".sneakycode/skills"
|
||||||
|
|||||||
144
docs/ROADMAP.md
144
docs/ROADMAP.md
@@ -1,144 +0,0 @@
|
|||||||
# SneakyCode Implementation Roadmap
|
|
||||||
|
|
||||||
A phased plan progressing from bare-bones foundation to full autonomous coding agent.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 1 — Foundation: Models, Config, and Utilities
|
|
||||||
|
|
||||||
Establish the data layer and shared infrastructure everything else builds on.
|
|
||||||
|
|
||||||
| File | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `app/models/config.py` | Pydantic v2 config model — load and validate `config/config.yaml` |
|
|
||||||
| `app/models/message.py` | Message schema (role, content, tool_calls) |
|
|
||||||
| `app/models/tool_call.py` | ToolCall and ToolResult schemas |
|
|
||||||
| `app/utils/logging.py` | Centralized logger with Rich handler |
|
|
||||||
| `app/utils/display.py` | Rich console output helpers (stub — expanded in Phase 2) |
|
|
||||||
| `app/utils/file_helpers.py` | Safe path resolution, binary detection, size guards |
|
|
||||||
| `app/utils/token_counter.py` | Approximate token usage tracking (character-based heuristic for v1) |
|
|
||||||
| `app/main.py` | Entrypoint stub — arg parsing, config load, Rich console setup |
|
|
||||||
|
|
||||||
**Exit criteria:** `python -m app.main --help` runs, config loads and validates, models can be instantiated and serialized.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 2 — TUI and Interactive Shell
|
|
||||||
|
|
||||||
Get a working interactive terminal before wiring up the LLM.
|
|
||||||
|
|
||||||
| File | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `app/main.py` | Rich-based interactive REPL loop — prompt for user input, display responses |
|
|
||||||
| `app/utils/display.py` | Formatted output for agent messages, tool calls, errors, token usage |
|
|
||||||
| `app/agent/context.py` | Session state and conversation history management |
|
|
||||||
|
|
||||||
**Exit criteria:** User can type messages into a styled REPL, see them echoed back with formatting, and conversation history is tracked in memory.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 3 — LLM Integration (Ollama)
|
|
||||||
|
|
||||||
Connect to the local LLM and stream responses into the TUI.
|
|
||||||
|
|
||||||
| File | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `app/services/llm.py` | Async httpx client wrapping Ollama's OpenAI-compatible `/v1/chat/completions` endpoint |
|
|
||||||
| `app/services/streaming.py` | SSE parsing, Rich live display, tool call extraction from accumulated stream |
|
|
||||||
|
|
||||||
**Integration:** Wire LLM into the REPL — user message goes to LLM, streamed response displays in real time.
|
|
||||||
|
|
||||||
**Exit criteria:** User can chat with the local model through the TUI with streamed output. Tool call JSON is parsed from the stream but not yet executed.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 4 — Tool Framework and Core Tools
|
|
||||||
|
|
||||||
Build the tool abstraction and implement safe, read-only tools first.
|
|
||||||
|
|
||||||
| File | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `app/tools/base.py` | `BaseTool` ABC and `ToolResult` dataclass |
|
|
||||||
| `app/tools/registry.py` | Tool registration, discovery, and JSON schema export for LLM system prompt |
|
|
||||||
| `app/services/permissions.py` | Two-tier approval gating (auto-approve reads; prompt for writes/deletes/shell) |
|
|
||||||
| `app/tools/filesystem.py` | `read_file`, `list_dir` |
|
|
||||||
| `app/tools/search.py` | `grep_files`, `find_files` |
|
|
||||||
|
|
||||||
**Exit criteria:** Tools register themselves, schemas export correctly for inclusion in the system prompt, read-only tools execute and return `ToolResult` objects. Permissions service gates execution.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 5 — Agent Loop (ReAct)
|
|
||||||
|
|
||||||
The core autonomy layer — reason, act, observe, repeat.
|
|
||||||
|
|
||||||
| File | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `app/agent/loop.py` | ReAct cycle: send conversation to LLM, parse tool calls, execute, feed results back, repeat |
|
|
||||||
|
|
||||||
**Key behaviors:**
|
|
||||||
- System prompt constructed with tool schemas from registry
|
|
||||||
- Permissions checks before each tool execution
|
|
||||||
- Loop termination on: plain-text response (no tool calls), explicit `finish` tool call, or `max_iterations` exceeded
|
|
||||||
|
|
||||||
**Exit criteria:** Agent can autonomously answer questions about the codebase by chaining `read_file`, `list_dir`, `grep_files`, and `find_files` tool calls in a multi-turn loop.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 6 — Write Tools and Shell
|
|
||||||
|
|
||||||
Unlock the agent's ability to modify code and run commands.
|
|
||||||
|
|
||||||
| File | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `app/tools/filesystem.py` | `write_file`, `make_dir`, `delete_file` (additions to existing module) |
|
|
||||||
| `app/tools/edit.py` | `str_replace` (unique-match required), `patch_apply` |
|
|
||||||
| `app/tools/shell.py` | `run_command` with command allow/deny lists and output truncation |
|
|
||||||
|
|
||||||
**All write/shell operations gated through permissions service.**
|
|
||||||
|
|
||||||
**Exit criteria:** Agent can autonomously create files, edit code via string replacement, and run shell commands — all with user approval for destructive operations.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Phase 7 — Polish and Hardening
|
|
||||||
|
|
||||||
Production-readiness: error handling, resource limits, and documentation.
|
|
||||||
|
|
||||||
| Area | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| Error handling | Recovery from malformed tool calls, LLM errors, network timeouts in agent loop |
|
|
||||||
| Token budget | Conversation truncation or summarization when approaching context limit |
|
|
||||||
| Graceful shutdown | Clean Ctrl+C handling, session state preservation |
|
|
||||||
| Testing | End-to-end integration tests (`tests/integration/`), unit tests (`tests/unit/`) |
|
|
||||||
| Documentation | `README.md` with setup and usage instructions, `docs/tools.md` tool reference |
|
|
||||||
|
|
||||||
**Exit criteria:** Agent handles edge cases gracefully, tests pass, and a new user can set up and use the project from the README alone.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## File Coverage
|
|
||||||
|
|
||||||
Every file from the project structure in CLAUDE.md is accounted for:
|
|
||||||
|
|
||||||
| File | Phase |
|
|
||||||
|------|-------|
|
|
||||||
| `app/main.py` | 1, 2 |
|
|
||||||
| `app/models/config.py` | 1 |
|
|
||||||
| `app/models/message.py` | 1 |
|
|
||||||
| `app/models/tool_call.py` | 1 |
|
|
||||||
| `app/utils/logging.py` | 1 |
|
|
||||||
| `app/utils/display.py` | 1, 2 |
|
|
||||||
| `app/utils/file_helpers.py` | 1 |
|
|
||||||
| `app/utils/token_counter.py` | 1 |
|
|
||||||
| `app/agent/context.py` | 2 |
|
|
||||||
| `app/services/llm.py` | 3 |
|
|
||||||
| `app/services/streaming.py` | 3 |
|
|
||||||
| `app/tools/base.py` | 4 |
|
|
||||||
| `app/tools/registry.py` | 4 |
|
|
||||||
| `app/services/permissions.py` | 4 |
|
|
||||||
| `app/tools/filesystem.py` | 4, 6 |
|
|
||||||
| `app/tools/search.py` | 4 |
|
|
||||||
| `app/agent/loop.py` | 5 |
|
|
||||||
| `app/tools/edit.py` | 6 |
|
|
||||||
| `app/tools/shell.py` | 6 |
|
|
||||||
1802
docs/superpowers/plans/2026-03-11-textual-tui.md
Normal file
1802
docs/superpowers/plans/2026-03-11-textual-tui.md
Normal file
File diff suppressed because it is too large
Load Diff
192
docs/superpowers/specs/2026-03-11-textual-tui-design.md
Normal file
192
docs/superpowers/specs/2026-03-11-textual-tui-design.md
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
# Textual TUI Redesign — Design Spec
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Replace the current sequential print-and-scroll terminal UI with a full persistent split-screen TUI using Textual. Input is pinned at the bottom, scrollable message history above, with a header showing app/model info and a footer showing token usage and iteration count.
|
||||||
|
|
||||||
|
## Layout
|
||||||
|
|
||||||
|
```
|
||||||
|
+------------------- Header --------------------+
|
||||||
|
| SneakyCode qwen2.5-coder:32b |
|
||||||
|
+-----------------------------------------------+
|
||||||
|
| |
|
||||||
|
| +--- You ---+ |
|
||||||
|
| | prompt | <- RichLog widget |
|
||||||
|
| +-----------+ (handles own scrolling) |
|
||||||
|
| |
|
||||||
|
| Thinking... |
|
||||||
|
| |
|
||||||
|
| +-- Assistant --+ |
|
||||||
|
| | response... | |
|
||||||
|
| +---------------+ |
|
||||||
|
| |
|
||||||
|
| > read_file README.md -- 148 lines, 5128 ch |
|
||||||
|
| > grep_files "pattern" -- 3 matches |
|
||||||
|
| |
|
||||||
|
+-----------------------------------------------+
|
||||||
|
| Tokens: ~1,511 / 32,000 | Iteration 5/25 | <- StatusBar
|
||||||
|
+-----------------------------------------------+
|
||||||
|
| > [input cursor] | <- Input widget
|
||||||
|
+-----------------------------------------------+
|
||||||
|
```
|
||||||
|
|
||||||
|
**Widget hierarchy (no VerticalScroll wrapper — RichLog handles its own scrolling):**
|
||||||
|
- `Header` — Textual built-in, title="SneakyCode", subtitle=model name
|
||||||
|
- `RichLog` (id="chat-log") — main scroll area, accepts Rich renderables via `.write()`
|
||||||
|
- `StreamingStatic` — persistent hidden `Static` widget, shown/hidden during streaming (avoids mount/unmount overhead)
|
||||||
|
- `StatusBar` — custom `Static` widget, 1 row, docked above Input
|
||||||
|
- `Input` — Textual built-in, pinned at bottom
|
||||||
|
|
||||||
|
## New Files
|
||||||
|
|
||||||
|
### `app/ui/app.py` — Textual App
|
||||||
|
|
||||||
|
SneakyCodeApp subclasses `textual.app.App`. Responsibilities:
|
||||||
|
|
||||||
|
- `compose()` yields: Header, RichLog(id="chat-log"), StreamingStatic(id="streaming"), StatusBar(id="status"), Input
|
||||||
|
- `on_input_submitted()` handler: reads input value, clears input, writes user panel to chat log, dispatches agent turn as a worker
|
||||||
|
- Agent turn runs via `run_worker()` (async worker, NOT threaded) so the UI stays responsive. Since the worker is async and on the event loop, widget methods can be called directly — no `call_from_thread()` needed.
|
||||||
|
- Slash commands (/quit, /history, /clear, /save, /session) parsed from input before dispatching to agent
|
||||||
|
- Holds references to config, SessionContext, AgentLoop (created in `on_mount`)
|
||||||
|
- Header subtitle set to model name from config
|
||||||
|
- `on_worker_state_changed()` handler: catches worker errors and writes error panels to RichLog
|
||||||
|
- Ctrl+C binding: cancels the running agent worker (does NOT quit the app). A second Ctrl+C or `/quit` exits.
|
||||||
|
|
||||||
|
### `app/ui/widgets.py` — Custom Widgets
|
||||||
|
|
||||||
|
**StatusBar** — A simple `Static` widget styled as a footer bar. Displays token usage and iteration count. Updated by the agent loop after each LLM step via `status_bar.update(renderable)`.
|
||||||
|
|
||||||
|
**StreamingStatic** — A `Static` widget that stays mounted but hidden. During streaming, it becomes visible and receives `update()` calls with partial content. When streaming ends, it is hidden and its content is cleared. This avoids the overhead of mounting/unmounting on every LLM response.
|
||||||
|
|
||||||
|
### `app/ui/styles.tcss` — Textual CSS
|
||||||
|
|
||||||
|
Layout rules:
|
||||||
|
- RichLog fills available height (fraction-based sizing, e.g. `height: 1fr`)
|
||||||
|
- StreamingStatic: `display: none` by default, shown during streaming
|
||||||
|
- StatusBar is 1 row, docked bottom above Input
|
||||||
|
- Input is 1 row, docked at very bottom
|
||||||
|
- Color scheme matches existing SNEAKYCODE_THEME (cyan for user, green for assistant, magenta for tools, dim for metadata)
|
||||||
|
|
||||||
|
## Modified Files
|
||||||
|
|
||||||
|
### `app/main.py`
|
||||||
|
|
||||||
|
- Remove `_run_repl()` async function entirely
|
||||||
|
- Remove `console.input()` usage
|
||||||
|
- `main()` creates config, runs preflight via `asyncio.run(_preflight(config))` (before Textual starts — this is fine, separate event loop), then instantiates and runs `SneakyCodeApp(config).run()`
|
||||||
|
- CLI arg parsing stays (--config, -v, --log-file)
|
||||||
|
- Session resume: `_offer_session_resume()` moves into `SneakyCodeApp.on_mount()` — instead of `console.input()`, push a modal screen asking "Resume previous session? [y/n]" with button/key handlers
|
||||||
|
- Auto-save: triggers after each agent turn completes (in the worker completion handler)
|
||||||
|
- SIGTERM handler: removed — Textual manages its own signal handling and shutdown lifecycle
|
||||||
|
|
||||||
|
### `app/services/streaming.py`
|
||||||
|
|
||||||
|
- Remove `from rich.live import Live` and `from rich.spinner import Spinner`
|
||||||
|
- `process_stream()` no longer creates a `Rich.Live` context
|
||||||
|
- Instead, accepts callback parameters:
|
||||||
|
- `on_content: Callable[[str], None]` — called with accumulated content on each content chunk
|
||||||
|
- `on_thinking: Callable[[], None]` — called once when first reasoning token arrives
|
||||||
|
- `on_done: Callable[[], None]` — called when streaming completes
|
||||||
|
- **Throttling:** Content callback fires at most every 100ms (track last update time, skip intermediate chunks). Final content always fires on stream end.
|
||||||
|
- Since the agent runs as an async worker (on the event loop), callbacks can directly call widget methods — no `call_from_thread()` needed.
|
||||||
|
- All accumulation and tool-call parsing logic stays identical
|
||||||
|
|
||||||
|
### `app/utils/display.py`
|
||||||
|
|
||||||
|
- All `print_*` functions become `render_*` functions that return Rich renderables:
|
||||||
|
- `render_user_message(content) -> Panel`
|
||||||
|
- `render_assistant_message(content) -> Panel`
|
||||||
|
- `render_tool_call(name, args) -> Text`
|
||||||
|
- `render_tool_result(name, output, is_error) -> Text`
|
||||||
|
- `render_iteration_header(iteration, max_iter) -> Text`
|
||||||
|
- `render_warning(message) -> Text`
|
||||||
|
- `render_error(message) -> Text`
|
||||||
|
- `print_banner()` removed — Header widget replaces it
|
||||||
|
- `print_token_usage()` becomes `render_token_usage() -> Text` for the StatusBar
|
||||||
|
- `print_history()` becomes `render_history() -> Table` — written to RichLog, may need width constraints for narrow terminals
|
||||||
|
- A `DisplayAdapter` class wraps a `RichLog` reference and provides `write_user_message()`, `write_tool_call()`, etc. methods that call `render_*` then `rich_log.write()`
|
||||||
|
|
||||||
|
### `app/agent/loop.py`
|
||||||
|
|
||||||
|
- `AgentLoop.__init__()` accepts a `DisplayAdapter` instead of calling `display.py` print functions directly
|
||||||
|
- All display calls route through the adapter: `self._display.write_tool_call(name, args)`, `self._display.write_iteration_header(i, max)`, etc.
|
||||||
|
- `_execute_tool_calls()` becomes `async def _execute_tool_calls()` to support async permission checks
|
||||||
|
- The loop logic (ReAct pattern, retry, truncation) is unchanged
|
||||||
|
|
||||||
|
### `app/services/permissions.py`
|
||||||
|
|
||||||
|
- `PermissionsService.check()` becomes `async def check()`
|
||||||
|
- Instead of `rich.prompt.Confirm.ask()` (blocking stdin read), it:
|
||||||
|
1. Creates an `asyncio.Event`
|
||||||
|
2. Posts a custom message to the app requesting a permission modal
|
||||||
|
3. The app pushes a modal screen with the permission question and approve/deny buttons
|
||||||
|
4. When the user responds, the modal sets the event and stores the result
|
||||||
|
5. `check()` awaits the event and reads the result
|
||||||
|
- Edge cases: dismiss without choosing = deny. Ctrl+C during modal = deny. Focus returns to Input after modal dismisses.
|
||||||
|
|
||||||
|
### `app/utils/logging.py`
|
||||||
|
|
||||||
|
- **Critical change:** The shared `console = Console()` instance will corrupt the Textual display since Textual takes exclusive terminal control
|
||||||
|
- When running under Textual: disable `RichHandler` (console handler), keep only the file handler
|
||||||
|
- Add a `setup_logging_for_tui()` function that reconfigures logging to file-only mode
|
||||||
|
- Called from `SneakyCodeApp.on_mount()` before any agent work begins
|
||||||
|
- The `console` object still exists but should not be used for output during TUI mode — all output goes through the DisplayAdapter
|
||||||
|
- Consider: `--log-file` becomes required (or auto-set to a default) when running in TUI mode, so logs are not lost
|
||||||
|
|
||||||
|
## Unchanged Files
|
||||||
|
|
||||||
|
- `app/services/llm.py` — HTTP client, SSE parsing untouched
|
||||||
|
- `app/agent/context.py` — session state untouched
|
||||||
|
- `app/models/*` — all data models untouched
|
||||||
|
- `app/tools/*` — all tool implementations untouched
|
||||||
|
- `app/utils/file_helpers.py` — path safety untouched
|
||||||
|
- `app/utils/token_counter.py` — token counting untouched
|
||||||
|
|
||||||
|
## Key Patterns
|
||||||
|
|
||||||
|
### Streaming in Textual
|
||||||
|
|
||||||
|
The agent loop runs as an async worker (on the event loop, NOT threaded). During streaming:
|
||||||
|
|
||||||
|
1. App shows `StreamingStatic` widget, writes "Thinking..." initially
|
||||||
|
2. Worker calls `StreamHandler.process_stream(chunks, on_content=..., on_thinking=..., on_done=...)`
|
||||||
|
3. `on_content` callback: updates `StreamingStatic` with `Panel(Markdown(partial_content), title="Assistant", border_style="green")` — throttled to ~100ms intervals
|
||||||
|
4. `on_done` callback: hides `StreamingStatic`, writes final content to `RichLog` via `DisplayAdapter`
|
||||||
|
|
||||||
|
Since the worker is async (not threaded), callbacks run on the event loop and can call widget methods directly.
|
||||||
|
|
||||||
|
### Permission Prompts
|
||||||
|
|
||||||
|
1. Agent loop (in async worker) calls `await permissions.check(operation, details)`
|
||||||
|
2. `check()` creates an `asyncio.Event` and posts `PermissionRequest` message to the app
|
||||||
|
3. App handles `PermissionRequest`: pushes a modal screen with the question, approve/deny buttons
|
||||||
|
4. Modal screen: on button press, stores result and sets the event
|
||||||
|
5. `check()` awaits the event, reads result, returns approved/denied
|
||||||
|
6. Focus management: Input loses focus when modal appears, regains focus when modal dismisses
|
||||||
|
7. Default on dismiss/Ctrl+C: deny
|
||||||
|
|
||||||
|
### Cancellation
|
||||||
|
|
||||||
|
- Ctrl+C (first press): cancels the running agent worker via `worker.cancel()`. The agent loop should check for cancellation between iterations.
|
||||||
|
- Ctrl+C (second press) or `/quit`: exits the app via `app.exit()`
|
||||||
|
|
||||||
|
## Dependencies
|
||||||
|
|
||||||
|
- Add `textual>=4.0.0` to pyproject.toml dependencies
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
1. Run the app — header shows app name + model, no console corruption
|
||||||
|
2. Type a prompt — user panel appears in scroll area, input clears
|
||||||
|
3. During LLM streaming — assistant response types out live (throttled) in the scroll area
|
||||||
|
4. Thinking indicator shows during reasoning-only phases
|
||||||
|
5. Tool calls appear as compact lines in the scroll area
|
||||||
|
6. Footer shows token usage and iteration count, updating each step
|
||||||
|
7. Scroll area auto-scrolls to bottom on new content
|
||||||
|
8. /quit, /clear, /history commands work from the input
|
||||||
|
9. Permission prompts show as modal, approve/deny work, focus returns to input
|
||||||
|
10. Ctrl+C cancels running agent turn without quitting
|
||||||
|
11. Worker errors display as error panels in the scroll area
|
||||||
|
12. Logging goes to file only — no console corruption
|
||||||
|
13. Session resume works on startup via modal dialog
|
||||||
240
docs/tools.md
Normal file
240
docs/tools.md
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
# Tool Reference
|
||||||
|
|
||||||
|
SneakyCode provides 11 agent-callable tools organized into 5 categories. All file path arguments must be **relative to the workspace root**.
|
||||||
|
|
||||||
|
## Permission Tiers
|
||||||
|
|
||||||
|
| Tier | Behavior | Tools |
|
||||||
|
|---------------|---------------------------------------|----------------------------------------------------------------|
|
||||||
|
| Auto-approved | Executed without user confirmation | `read_file`, `list_dir`, `grep_files`, `find_files`, `finish` |
|
||||||
|
| User confirm | Prompts user before execution | `write_file`, `make_dir`, `delete_file`, `str_replace`, `patch_apply`, `run_command` |
|
||||||
|
| Denied | Blocked entirely (configurable) | Any tool added to `permissions.deny` in config |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Read Tools
|
||||||
|
|
||||||
|
### read_file
|
||||||
|
|
||||||
|
Read the full contents of a text file.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Description |
|
||||||
|
|-------------|------|----------|------------------------------------------------|
|
||||||
|
| `file_path` | str | Yes | Path to the file to read (relative to workspace) |
|
||||||
|
|
||||||
|
**Permission:** Auto-approved
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"file_path": "app/main.py"}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Notes:** Binary files are detected and rejected. Files exceeding `max_file_size_bytes` (default 1 MB) are rejected.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### list_dir
|
||||||
|
|
||||||
|
List the contents of a directory. Directories are suffixed with `/`. Results are sorted with directories first, then files.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Default | Description |
|
||||||
|
|------------------|------|----------|---------|--------------------------------------|
|
||||||
|
| `directory_path` | str | No | `"."` | Path to directory (relative) |
|
||||||
|
| `recursive` | bool | No | `false` | If true, list entries recursively |
|
||||||
|
|
||||||
|
**Permission:** Auto-approved
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"directory_path": "app/tools", "recursive": true}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Search Tools
|
||||||
|
|
||||||
|
### grep_files
|
||||||
|
|
||||||
|
Search for a regex pattern in file contents. Returns matching lines with file paths and line numbers.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Default | Description |
|
||||||
|
|----------------|------------|----------|---------|------------------------------------------|
|
||||||
|
| `pattern` | str | Yes | | Regular expression pattern to search for |
|
||||||
|
| `path` | str | No | `"."` | Directory or file to search in |
|
||||||
|
| `file_pattern` | str\|null | No | `null` | Glob pattern to filter files (e.g. `*.py`) |
|
||||||
|
|
||||||
|
**Permission:** Auto-approved
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"pattern": "def main", "path": "app/", "file_pattern": "*.py"}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### find_files
|
||||||
|
|
||||||
|
Search for files matching a name pattern. Returns relative file paths.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Default | Description |
|
||||||
|
|-----------|------|----------|---------|---------------------------------------------------|
|
||||||
|
| `pattern` | str | Yes | | File name pattern (e.g. `*.py`, `config.yaml`) |
|
||||||
|
| `path` | str | No | `"."` | Directory to search in |
|
||||||
|
|
||||||
|
**Permission:** Auto-approved
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"pattern": "*.yaml", "path": "config/"}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Write Tools
|
||||||
|
|
||||||
|
### write_file
|
||||||
|
|
||||||
|
Write text content to a file. Creates parent directories if needed. Overwrites existing file content.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Description |
|
||||||
|
|-------------|------|----------|---------------------------------|
|
||||||
|
| `file_path` | str | Yes | Path to the file to write |
|
||||||
|
| `content` | str | Yes | Content to write to the file |
|
||||||
|
|
||||||
|
**Permission:** User confirmation required
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"file_path": "app/utils/helpers.py", "content": "def greet():\n return 'hello'\n"}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### make_dir
|
||||||
|
|
||||||
|
Create a directory and any necessary parent directories.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Description |
|
||||||
|
|------------------|------|----------|----------------------------------|
|
||||||
|
| `directory_path` | str | Yes | Path to the directory to create |
|
||||||
|
|
||||||
|
**Permission:** User confirmation required
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"directory_path": "app/services/new_module"}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### delete_file
|
||||||
|
|
||||||
|
Delete a single file. Does not delete directories.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Description |
|
||||||
|
|-------------|------|----------|---------------------------------|
|
||||||
|
| `file_path` | str | Yes | Path to the file to delete |
|
||||||
|
|
||||||
|
**Permission:** User confirmation required
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"file_path": "app/utils/deprecated.py"}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Edit Tools
|
||||||
|
|
||||||
|
### str_replace
|
||||||
|
|
||||||
|
Replace exactly one occurrence of `old_str` with `new_str` in a file. Fails if `old_str` is not found or appears more than once.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Description |
|
||||||
|
|-------------|------|----------|----------------------------------------------------|
|
||||||
|
| `file_path` | str | Yes | Path to the file to edit |
|
||||||
|
| `old_str` | str | Yes | The exact string to find and replace (must be unique) |
|
||||||
|
| `new_str` | str | Yes | The replacement string |
|
||||||
|
|
||||||
|
**Permission:** User confirmation required
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"file_path": "app/main.py",
|
||||||
|
"old_str": "def old_function():",
|
||||||
|
"new_str": "def new_function():"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### patch_apply
|
||||||
|
|
||||||
|
Apply a unified diff (patch) to a file. The patch must be in standard unified diff format.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Description |
|
||||||
|
|-------------|------|----------|--------------------------------------------|
|
||||||
|
| `file_path` | str | Yes | Path to the file to patch |
|
||||||
|
| `patch` | str | Yes | Unified diff format patch to apply |
|
||||||
|
|
||||||
|
**Permission:** User confirmation required
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"file_path": "app/main.py",
|
||||||
|
"patch": "--- a/app/main.py\n+++ b/app/main.py\n@@ -1,3 +1,3 @@\n-old line\n+new line\n"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Shell Tools
|
||||||
|
|
||||||
|
### run_command
|
||||||
|
|
||||||
|
Run a shell command in the workspace directory. Only allowed commands may be executed; dangerous commands are blocked.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Default | Description |
|
||||||
|
|-----------|----------|----------|---------|-----------------------------------|
|
||||||
|
| `command` | str | Yes | | Shell command to execute |
|
||||||
|
| `timeout` | int\|null | No | `30` | Timeout in seconds |
|
||||||
|
|
||||||
|
**Permission:** User confirmation required. Subject to `tools.shell.allowed_commands` and `tools.shell.denied_commands` in config.
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"command": "git status", "timeout": 10}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Notes:** Output is truncated to `max_output_bytes` (default 64 KB). The command's first word is checked against allow/deny lists.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Control Tools
|
||||||
|
|
||||||
|
### finish
|
||||||
|
|
||||||
|
Signal that the task is complete. Terminates the agent loop.
|
||||||
|
|
||||||
|
| Parameter | Type | Required | Default | Description |
|
||||||
|
|-----------|------|----------|--------------------|------------------------------|
|
||||||
|
| `message` | str | No | `"Task complete."` | Final message to the user |
|
||||||
|
|
||||||
|
**Permission:** Auto-approved
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```json
|
||||||
|
{"message": "Created the new module with tests."}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Security Notes
|
||||||
|
|
||||||
|
- All file paths are resolved against `workspace_root` with path traversal protection
|
||||||
|
- Binary file detection prevents reading/writing binary files
|
||||||
|
- File size limits prevent reading/writing excessively large files
|
||||||
|
- Shell commands are validated against configurable allow/deny lists
|
||||||
|
- Tool call arguments from the LLM are validated against JSON schema before execution
|
||||||
@@ -1 +1,12 @@
|
|||||||
Pressing up should cycle history like claude code.
|
# UI Issues
|
||||||
|
on /clear we need to reset the token counter in the header panel.
|
||||||
|
|
||||||
|
# Bugs
|
||||||
|
|
||||||
|
# Improvements
|
||||||
|
add -p to command line args so that the agent can run the prompt and return data directly via STDOUT
|
||||||
|
|
||||||
|
# Open questions:
|
||||||
|
How might we pass a directory to this app and have it use that directory as it's workspace so I don't have to copy files or do odd things to work in other directories.
|
||||||
|
|
||||||
|
How do we handle huge files not taking up so many tokens?
|
||||||
@@ -13,6 +13,7 @@ dependencies = [
|
|||||||
"pyyaml>=6.0",
|
"pyyaml>=6.0",
|
||||||
"httpx>=0.27",
|
"httpx>=0.27",
|
||||||
"structlog>=24.0",
|
"structlog>=24.0",
|
||||||
|
"textual>=4.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
150
tests/integration/conftest.py
Normal file
150
tests/integration/conftest.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""Shared fixtures for integration tests."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from collections.abc import AsyncIterator
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import AgentConfig, AppConfig, DisplayConfig, LLMConfig, PermissionsConfig, SessionConfig
|
||||||
|
from app.models.message import Message
|
||||||
|
from app.services.llm import LLMClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_workspace(tmp_path: Path) -> Path:
|
||||||
|
"""Create a temporary workspace directory with a sample file."""
|
||||||
|
ws = tmp_path / "workspace"
|
||||||
|
ws.mkdir()
|
||||||
|
(ws / "hello.txt").write_text("Hello, world!")
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_config(tmp_workspace: Path) -> AppConfig:
|
||||||
|
"""AppConfig suitable for integration tests."""
|
||||||
|
return AppConfig(
|
||||||
|
llm=LLMConfig(
|
||||||
|
model="test-model",
|
||||||
|
endpoint="http://localhost:11434",
|
||||||
|
max_retries=2,
|
||||||
|
retry_backoff_base=0.01,
|
||||||
|
retry_backoff_max=0.02,
|
||||||
|
),
|
||||||
|
agent=AgentConfig(
|
||||||
|
max_iterations=10,
|
||||||
|
max_conversation_tokens=32000,
|
||||||
|
workspace_root=tmp_workspace,
|
||||||
|
truncation_keep_recent=4,
|
||||||
|
truncation_threshold=0.85,
|
||||||
|
),
|
||||||
|
permissions=PermissionsConfig(
|
||||||
|
auto_approve=["read_file", "list_dir", "grep_files", "find_files", "finish"],
|
||||||
|
),
|
||||||
|
display=DisplayConfig(
|
||||||
|
show_tool_calls=False,
|
||||||
|
show_token_usage=False,
|
||||||
|
stream_output=False,
|
||||||
|
),
|
||||||
|
session=SessionConfig(
|
||||||
|
session_dir=tmp_workspace / ".sneakycode" / "sessions",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_text_chunks(content: str) -> list[dict[str, Any]]:
|
||||||
|
"""Create SSE chunk dicts for a plain text response."""
|
||||||
|
chunks = []
|
||||||
|
for char in content:
|
||||||
|
chunks.append({
|
||||||
|
"choices": [{"delta": {"content": char}, "index": 0}]
|
||||||
|
})
|
||||||
|
chunks.append({
|
||||||
|
"choices": [{"delta": {}, "finish_reason": "stop", "index": 0}],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": len(content), "total_tokens": 10 + len(content)},
|
||||||
|
})
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def make_tool_call_chunks(name: str, args: dict[str, Any], tc_id: str = "call_001") -> list[dict[str, Any]]:
|
||||||
|
"""Create SSE chunk dicts for a tool call response."""
|
||||||
|
args_str = json.dumps(args)
|
||||||
|
chunks = [
|
||||||
|
{
|
||||||
|
"choices": [{
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [{
|
||||||
|
"index": 0,
|
||||||
|
"id": tc_id,
|
||||||
|
"function": {"name": name, "arguments": ""},
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"index": 0,
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [{
|
||||||
|
"delta": {
|
||||||
|
"tool_calls": [{
|
||||||
|
"index": 0,
|
||||||
|
"function": {"arguments": args_str},
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
"index": 0,
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [{"delta": {}, "finish_reason": "tool_calls", "index": 0}],
|
||||||
|
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
class MockLLMClient:
|
||||||
|
"""LLM client that returns scripted SSE chunk sequences."""
|
||||||
|
|
||||||
|
def __init__(self, responses: list[list[dict[str, Any]]]) -> None:
|
||||||
|
self._responses = list(responses)
|
||||||
|
self._call_count = 0
|
||||||
|
|
||||||
|
async def stream_chat(
|
||||||
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> AsyncIterator[dict]:
|
||||||
|
if self._call_count >= len(self._responses):
|
||||||
|
raise RuntimeError("MockLLMClient ran out of scripted responses")
|
||||||
|
chunks = self._responses[self._call_count]
|
||||||
|
self._call_count += 1
|
||||||
|
for chunk in chunks:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def stream_chat_with_retry(
|
||||||
|
self,
|
||||||
|
messages: list[Message],
|
||||||
|
tools: list[dict[str, Any]] | None = None,
|
||||||
|
) -> AsyncIterator[dict]:
|
||||||
|
async for chunk in self.stream_chat(messages, tools=tools):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
@property
|
||||||
|
def call_count(self) -> int:
|
||||||
|
return self._call_count
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm_client():
|
||||||
|
"""Factory fixture for creating MockLLMClient instances."""
|
||||||
|
return MockLLMClient
|
||||||
155
tests/integration/test_agent_workflows.py
Normal file
155
tests/integration/test_agent_workflows.py
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
"""Integration tests for end-to-end agent workflows with mocked LLM."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agent.context import SessionContext
|
||||||
|
from app.agent.loop import AgentLoop
|
||||||
|
from app.models.config import AgentConfig, AppConfig, LLMConfig
|
||||||
|
from app.services.llm import LLMConnectionError
|
||||||
|
from app.services.permissions import PermissionsService
|
||||||
|
from app.services.session import SessionManager
|
||||||
|
from app.services.streaming import StreamHandler
|
||||||
|
from app.tools.registry import create_default_registry
|
||||||
|
|
||||||
|
from .conftest import MockLLMClient, make_text_chunks, make_tool_call_chunks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def agent_factory(test_config: AppConfig, tmp_workspace: Path):
|
||||||
|
"""Factory that creates an AgentLoop wired to a MockLLMClient."""
|
||||||
|
|
||||||
|
def create(responses):
|
||||||
|
ctx = SessionContext(test_config)
|
||||||
|
mock_client = MockLLMClient(responses)
|
||||||
|
handler = StreamHandler(test_config.display)
|
||||||
|
registry = create_default_registry(test_config.agent.workspace_root, test_config)
|
||||||
|
permissions = PermissionsService(test_config.permissions)
|
||||||
|
agent = AgentLoop(test_config, ctx, mock_client, handler, registry, permissions)
|
||||||
|
return agent, ctx, mock_client
|
||||||
|
|
||||||
|
return create
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentWorkflows:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multi_turn_read_workflow(self, agent_factory, tmp_workspace: Path) -> None:
|
||||||
|
"""Agent reads a file then responds with text — full 2-turn workflow."""
|
||||||
|
responses = [
|
||||||
|
make_tool_call_chunks("read_file", {"file_path": "hello.txt"}),
|
||||||
|
make_text_chunks("The file contains: Hello, world!"),
|
||||||
|
]
|
||||||
|
agent, ctx, mock_client = agent_factory(responses)
|
||||||
|
|
||||||
|
await agent.run_turn("What's in hello.txt?")
|
||||||
|
|
||||||
|
assert mock_client.call_count == 2
|
||||||
|
history = ctx.get_history()
|
||||||
|
# user, assistant (tool_call), tool (result), assistant (text)
|
||||||
|
assert len(history) == 4
|
||||||
|
assert history[0].role == "user"
|
||||||
|
assert history[1].role == "assistant"
|
||||||
|
assert history[1].tool_calls is not None
|
||||||
|
assert history[2].role == "tool"
|
||||||
|
assert "Hello, world!" in (history[2].content or "")
|
||||||
|
assert history[3].role == "assistant"
|
||||||
|
assert "Hello, world!" in (history[3].content or "")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_token_budget_truncation(self, test_config: AppConfig, tmp_workspace: Path) -> None:
|
||||||
|
"""When token budget is exceeded, truncation drops messages instead of stopping."""
|
||||||
|
# Use a tiny budget to trigger truncation
|
||||||
|
test_config.agent.max_conversation_tokens = 100
|
||||||
|
test_config.agent.truncation_keep_recent = 2
|
||||||
|
test_config.agent.truncation_threshold = 0.5
|
||||||
|
|
||||||
|
ctx = SessionContext(test_config)
|
||||||
|
|
||||||
|
# Fill history with enough to exceed budget
|
||||||
|
for i in range(10):
|
||||||
|
ctx.add_message("user", f"Message {i} " * 20)
|
||||||
|
ctx.add_message("assistant", f"Response {i} " * 20)
|
||||||
|
|
||||||
|
# Force token counter over budget
|
||||||
|
from app.utils.token_counter import TokenUsage
|
||||||
|
ctx.token_counter.count_usage(TokenUsage(total_tokens=100))
|
||||||
|
|
||||||
|
original_count = len(ctx.get_history())
|
||||||
|
dropped = ctx.truncate_history()
|
||||||
|
|
||||||
|
assert dropped > 0
|
||||||
|
assert len(ctx.get_history()) < original_count
|
||||||
|
# Recent messages should still be present
|
||||||
|
assert len(ctx.get_history()) >= 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_save_and_restore(self, test_config: AppConfig, tmp_workspace: Path) -> None:
|
||||||
|
"""Session can be saved after a turn and restored into a fresh context."""
|
||||||
|
responses = [make_text_chunks("Hello from the agent!")]
|
||||||
|
ctx = SessionContext(test_config)
|
||||||
|
mock_client = MockLLMClient(responses)
|
||||||
|
handler = StreamHandler(test_config.display)
|
||||||
|
registry = create_default_registry(test_config.agent.workspace_root, test_config)
|
||||||
|
permissions = PermissionsService(test_config.permissions)
|
||||||
|
agent = AgentLoop(test_config, ctx, mock_client, handler, registry, permissions)
|
||||||
|
|
||||||
|
await agent.run_turn("Hi")
|
||||||
|
|
||||||
|
# Save session
|
||||||
|
session_mgr = SessionManager(
|
||||||
|
test_config.session, test_config.agent.workspace_root, test_config.llm.model
|
||||||
|
)
|
||||||
|
path = session_mgr.save(ctx)
|
||||||
|
assert path.exists()
|
||||||
|
|
||||||
|
# Restore into fresh context
|
||||||
|
fresh_ctx = SessionContext(test_config)
|
||||||
|
loaded = session_mgr.load_latest()
|
||||||
|
assert loaded is not None
|
||||||
|
session_mgr.restore(loaded, fresh_ctx)
|
||||||
|
|
||||||
|
assert fresh_ctx.message_count == ctx.message_count
|
||||||
|
original_history = ctx.get_history()
|
||||||
|
restored_history = fresh_ctx.get_history()
|
||||||
|
for orig, restored in zip(original_history, restored_history):
|
||||||
|
assert orig.role == restored.role
|
||||||
|
assert orig.content == restored.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retry_on_transient_error(self, test_config: AppConfig, tmp_workspace: Path) -> None:
|
||||||
|
"""Agent recovers from transient LLM errors via retry."""
|
||||||
|
from app.services.llm import LLMClient
|
||||||
|
|
||||||
|
ctx = SessionContext(test_config)
|
||||||
|
|
||||||
|
# Create a real client but mock stream_chat to fail then succeed
|
||||||
|
client = LLMClient(test_config.llm)
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def flaky_stream(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
raise LLMConnectionError("Temporary failure")
|
||||||
|
for chunk in make_text_chunks("Recovered!"):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
client.stream_chat = flaky_stream # type: ignore[assignment]
|
||||||
|
|
||||||
|
handler = StreamHandler(test_config.display)
|
||||||
|
registry = create_default_registry(test_config.agent.workspace_root, test_config)
|
||||||
|
permissions = PermissionsService(test_config.permissions)
|
||||||
|
agent = AgentLoop(test_config, ctx, client, handler, registry, permissions)
|
||||||
|
|
||||||
|
with patch("app.services.llm.asyncio.sleep", new_callable=AsyncMock):
|
||||||
|
await agent.run_turn("Test retry")
|
||||||
|
|
||||||
|
history = ctx.get_history()
|
||||||
|
# Should have succeeded: user + assistant
|
||||||
|
assert len(history) == 2
|
||||||
|
assert history[1].content == "Recovered!"
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
await client.close()
|
||||||
@@ -21,6 +21,7 @@ from app.services.llm import LLMClient
|
|||||||
from app.services.permissions import PermissionsService
|
from app.services.permissions import PermissionsService
|
||||||
from app.services.streaming import StreamHandler
|
from app.services.streaming import StreamHandler
|
||||||
from app.tools.registry import ToolRegistry, create_default_registry
|
from app.tools.registry import ToolRegistry, create_default_registry
|
||||||
|
from app.utils.display import DisplayAdapter
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -62,6 +63,7 @@ def handler() -> MagicMock:
|
|||||||
mock.usage = None
|
mock.usage = None
|
||||||
mock.had_reasoning_only = False
|
mock.had_reasoning_only = False
|
||||||
mock.reset = MagicMock()
|
mock.reset = MagicMock()
|
||||||
|
mock.get_partial_message = MagicMock(return_value=None)
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
|
|
||||||
@@ -75,6 +77,11 @@ def permissions(config: AppConfig) -> PermissionsService:
|
|||||||
return PermissionsService(config.permissions)
|
return PermissionsService(config.permissions)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def display() -> MagicMock:
|
||||||
|
return MagicMock(spec=DisplayAdapter)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def agent(
|
def agent(
|
||||||
config: AppConfig,
|
config: AppConfig,
|
||||||
@@ -83,8 +90,9 @@ def agent(
|
|||||||
handler: MagicMock,
|
handler: MagicMock,
|
||||||
registry: ToolRegistry,
|
registry: ToolRegistry,
|
||||||
permissions: PermissionsService,
|
permissions: PermissionsService,
|
||||||
|
display: MagicMock,
|
||||||
) -> AgentLoop:
|
) -> AgentLoop:
|
||||||
return AgentLoop(config, ctx, client, handler, registry, permissions)
|
return AgentLoop(config, ctx, client, handler, registry, permissions, display)
|
||||||
|
|
||||||
|
|
||||||
def _make_text_message(content: str) -> Message:
|
def _make_text_message(content: str) -> Message:
|
||||||
|
|||||||
87
tests/unit/test_display.py
Normal file
87
tests/unit/test_display.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
"""Tests for display render functions and DisplayAdapter."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from rich.panel import Panel
|
||||||
|
from rich.table import Table
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
from app.models.message import Message
|
||||||
|
from app.utils.display import (
|
||||||
|
DisplayAdapter,
|
||||||
|
render_assistant_message,
|
||||||
|
render_error,
|
||||||
|
render_iteration_header,
|
||||||
|
render_tool_call,
|
||||||
|
render_tool_result,
|
||||||
|
render_token_usage,
|
||||||
|
render_user_message,
|
||||||
|
render_warning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRenderFunctions:
|
||||||
|
def test_render_user_message_returns_text(self) -> None:
|
||||||
|
result = render_user_message("hello")
|
||||||
|
assert isinstance(result, Text)
|
||||||
|
assert "hello" in result.plain
|
||||||
|
|
||||||
|
def test_render_assistant_message_returns_panel(self) -> None:
|
||||||
|
result = render_assistant_message("response")
|
||||||
|
assert isinstance(result, Panel)
|
||||||
|
assert result.title == "Assistant"
|
||||||
|
|
||||||
|
def test_render_tool_call_returns_text(self) -> None:
|
||||||
|
result = render_tool_call("read_file", '{"path": "foo.py"}')
|
||||||
|
assert isinstance(result, Text)
|
||||||
|
assert "read_file" in result.plain
|
||||||
|
|
||||||
|
def test_render_tool_result_success(self) -> None:
|
||||||
|
result = render_tool_result("read_file", "file contents here", is_error=False)
|
||||||
|
assert isinstance(result, Text)
|
||||||
|
assert "read_file" in result.plain
|
||||||
|
|
||||||
|
def test_render_tool_result_error(self) -> None:
|
||||||
|
result = render_tool_result("read_file", "not found", is_error=True)
|
||||||
|
assert isinstance(result, Text)
|
||||||
|
assert "read_file" in result.plain
|
||||||
|
|
||||||
|
def test_render_iteration_header(self) -> None:
|
||||||
|
result = render_iteration_header(3, 25)
|
||||||
|
assert isinstance(result, Text)
|
||||||
|
assert "3/25" in result.plain
|
||||||
|
|
||||||
|
def test_render_token_usage(self) -> None:
|
||||||
|
result = render_token_usage(1500, 32000)
|
||||||
|
assert isinstance(result, Text)
|
||||||
|
assert "1,500" in result.plain
|
||||||
|
|
||||||
|
def test_render_warning(self) -> None:
|
||||||
|
result = render_warning("something happened")
|
||||||
|
assert isinstance(result, Text)
|
||||||
|
|
||||||
|
def test_render_error(self) -> None:
|
||||||
|
result = render_error("bad thing")
|
||||||
|
assert isinstance(result, Text)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisplayAdapter:
|
||||||
|
def test_write_user_message(self) -> None:
|
||||||
|
mock_log = MagicMock()
|
||||||
|
adapter = DisplayAdapter(mock_log)
|
||||||
|
adapter.write_user_message("hello")
|
||||||
|
mock_log.write.assert_called_once()
|
||||||
|
arg = mock_log.write.call_args[0][0]
|
||||||
|
assert isinstance(arg, Text)
|
||||||
|
|
||||||
|
def test_write_tool_call(self) -> None:
|
||||||
|
mock_log = MagicMock()
|
||||||
|
adapter = DisplayAdapter(mock_log)
|
||||||
|
adapter.write_tool_call("read_file", '{"path": "x"}')
|
||||||
|
mock_log.write.assert_called_once()
|
||||||
|
|
||||||
|
def test_write_warning(self) -> None:
|
||||||
|
mock_log = MagicMock()
|
||||||
|
adapter = DisplayAdapter(mock_log)
|
||||||
|
adapter.write_warning("oops")
|
||||||
|
mock_log.write.assert_called_once()
|
||||||
314
tests/unit/test_file_cache.py
Normal file
314
tests/unit/test_file_cache.py
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
"""Tests for the file cache with LRU eviction and mtime invalidation."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import AppConfig, load_config
|
||||||
|
from app.models.tool_call import ToolResultStatus
|
||||||
|
from app.tools.filesystem import ReadFileTool, ReadManyFilesTool
|
||||||
|
from app.utils.file_cache import CacheStats, FileCache, cached_read_file
|
||||||
|
from app.utils.file_helpers import BinaryFileError, FileSizeError, PathSecurityError
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# FileCache unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFileCache:
|
||||||
|
def test_put_and_get_roundtrip(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
f = tmp_path / "hello.txt"
|
||||||
|
f.write_text("hello world")
|
||||||
|
|
||||||
|
cache.put(f, "hello world")
|
||||||
|
assert cache.get(f) == "hello world"
|
||||||
|
|
||||||
|
def test_get_returns_none_for_missing_key(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
assert cache.get(tmp_path / "nope.txt") is None
|
||||||
|
|
||||||
|
def test_mtime_change_causes_miss(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
f = tmp_path / "data.txt"
|
||||||
|
f.write_text("v1")
|
||||||
|
cache.put(f, "v1")
|
||||||
|
|
||||||
|
# Mutate the file so mtime changes
|
||||||
|
time.sleep(0.05) # ensure mtime differs
|
||||||
|
f.write_text("v2")
|
||||||
|
|
||||||
|
assert cache.get(f) is None # stale → miss
|
||||||
|
assert cache.stats.invalidations == 1
|
||||||
|
|
||||||
|
def test_lru_eviction_at_capacity(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache(max_entries=3)
|
||||||
|
files = []
|
||||||
|
for i in range(4):
|
||||||
|
f = tmp_path / f"f{i}.txt"
|
||||||
|
f.write_text(f"content-{i}")
|
||||||
|
files.append(f)
|
||||||
|
|
||||||
|
# Fill cache to capacity
|
||||||
|
for f in files[:3]:
|
||||||
|
cache.put(f, f.read_text())
|
||||||
|
assert len(cache) == 3
|
||||||
|
|
||||||
|
# Adding a 4th evicts the LRU (files[0])
|
||||||
|
cache.put(files[3], files[3].read_text())
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert cache.get(files[0]) is None # evicted
|
||||||
|
assert cache.stats.evictions == 1
|
||||||
|
|
||||||
|
# files[1..3] still present
|
||||||
|
for f in files[1:]:
|
||||||
|
assert cache.get(f) is not None
|
||||||
|
|
||||||
|
def test_invalidate_removes_entry(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
f = tmp_path / "rm.txt"
|
||||||
|
f.write_text("bye")
|
||||||
|
cache.put(f, "bye")
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache.invalidate(f)
|
||||||
|
assert len(cache) == 0
|
||||||
|
assert cache.get(f) is None
|
||||||
|
assert cache.stats.invalidations == 1
|
||||||
|
|
||||||
|
def test_invalidate_noop_for_missing(self) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
cache.invalidate(Path("/nonexistent"))
|
||||||
|
assert cache.stats.invalidations == 0
|
||||||
|
|
||||||
|
def test_clear_empties_cache(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
for i in range(5):
|
||||||
|
f = tmp_path / f"c{i}.txt"
|
||||||
|
f.write_text(str(i))
|
||||||
|
cache.put(f, str(i))
|
||||||
|
assert len(cache) == 5
|
||||||
|
|
||||||
|
cache.clear()
|
||||||
|
assert len(cache) == 0
|
||||||
|
|
||||||
|
def test_stats_accuracy(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache(max_entries=2)
|
||||||
|
a = tmp_path / "a.txt"
|
||||||
|
b = tmp_path / "b.txt"
|
||||||
|
c = tmp_path / "c.txt"
|
||||||
|
a.write_text("a")
|
||||||
|
b.write_text("b")
|
||||||
|
c.write_text("c")
|
||||||
|
|
||||||
|
# Miss
|
||||||
|
cache.get(a)
|
||||||
|
assert cache.stats.misses == 1
|
||||||
|
assert cache.stats.hits == 0
|
||||||
|
|
||||||
|
# Put + hit
|
||||||
|
cache.put(a, "a")
|
||||||
|
cache.get(a)
|
||||||
|
assert cache.stats.hits == 1
|
||||||
|
|
||||||
|
# Fill + evict
|
||||||
|
cache.put(b, "b")
|
||||||
|
cache.put(c, "c") # evicts a
|
||||||
|
assert cache.stats.evictions == 1
|
||||||
|
|
||||||
|
def test_hit_rate(self) -> None:
|
||||||
|
stats = CacheStats(hits=3, misses=1)
|
||||||
|
assert stats.hit_rate == pytest.approx(0.75)
|
||||||
|
|
||||||
|
def test_hit_rate_zero_total(self) -> None:
|
||||||
|
stats = CacheStats()
|
||||||
|
assert stats.hit_rate == 0.0
|
||||||
|
|
||||||
|
def test_file_deleted_after_caching(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
f = tmp_path / "gone.txt"
|
||||||
|
f.write_text("here")
|
||||||
|
cache.put(f, "here")
|
||||||
|
|
||||||
|
f.unlink()
|
||||||
|
assert cache.get(f) is None # stat fails → miss
|
||||||
|
|
||||||
|
def test_put_skips_when_stat_fails(self) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
cache.put(Path("/totally/nonexistent"), "data")
|
||||||
|
assert len(cache) == 0
|
||||||
|
|
||||||
|
def test_get_moves_to_end(self, tmp_path: Path) -> None:
|
||||||
|
"""Accessing an entry makes it most-recently-used, protecting from eviction."""
|
||||||
|
cache = FileCache(max_entries=3)
|
||||||
|
files = []
|
||||||
|
for i in range(3):
|
||||||
|
f = tmp_path / f"lru{i}.txt"
|
||||||
|
f.write_text(f"c{i}")
|
||||||
|
files.append(f)
|
||||||
|
cache.put(f, f"c{i}")
|
||||||
|
|
||||||
|
# Touch files[0] to make it MRU
|
||||||
|
cache.get(files[0])
|
||||||
|
|
||||||
|
# Add a new entry — files[1] (LRU) should be evicted, not files[0]
|
||||||
|
extra = tmp_path / "extra.txt"
|
||||||
|
extra.write_text("x")
|
||||||
|
cache.put(extra, "x")
|
||||||
|
|
||||||
|
assert cache.get(files[0]) is not None # protected by access
|
||||||
|
assert cache.get(files[1]) is None # evicted
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# cached_read_file tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCachedReadFile:
|
||||||
|
def test_without_cache_matches_safe_read(self, tmp_path: Path) -> None:
|
||||||
|
f = tmp_path / "plain.txt"
|
||||||
|
f.write_text("hello")
|
||||||
|
content = cached_read_file(f, tmp_path, cache=None)
|
||||||
|
assert content == "hello"
|
||||||
|
|
||||||
|
def test_populates_on_miss_returns_on_hit(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
f = tmp_path / "cached.txt"
|
||||||
|
f.write_text("data")
|
||||||
|
|
||||||
|
# First call: miss → read from disk → populate cache
|
||||||
|
content1 = cached_read_file(f, tmp_path, cache=cache)
|
||||||
|
assert content1 == "data"
|
||||||
|
assert cache.stats.misses == 1
|
||||||
|
assert cache.stats.hits == 0
|
||||||
|
|
||||||
|
# Second call: hit → from cache
|
||||||
|
content2 = cached_read_file(f, tmp_path, cache=cache)
|
||||||
|
assert content2 == "data"
|
||||||
|
assert cache.stats.hits == 1
|
||||||
|
|
||||||
|
def test_security_checks_run_on_cached_path(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
with pytest.raises(PathSecurityError):
|
||||||
|
cached_read_file("/etc/passwd", tmp_path, cache=cache)
|
||||||
|
|
||||||
|
def test_binary_check_runs_on_cached_path(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
f = tmp_path / "bin.dat"
|
||||||
|
f.write_bytes(b"\x00binary\x00")
|
||||||
|
with pytest.raises(BinaryFileError):
|
||||||
|
cached_read_file(f, tmp_path, cache=cache)
|
||||||
|
|
||||||
|
def test_size_check_runs_on_cached_path(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
f = tmp_path / "big.txt"
|
||||||
|
f.write_text("x" * 200)
|
||||||
|
|
||||||
|
# First read populates cache
|
||||||
|
cached_read_file(f, tmp_path, max_size_bytes=1000, cache=cache)
|
||||||
|
|
||||||
|
# Now make file too big on disk — security check should catch it
|
||||||
|
# even though content is cached
|
||||||
|
f.write_text("x" * 2000)
|
||||||
|
with pytest.raises(FileSizeError):
|
||||||
|
cached_read_file(f, tmp_path, max_size_bytes=1000, cache=cache)
|
||||||
|
|
||||||
|
def test_file_not_found(self, tmp_path: Path) -> None:
|
||||||
|
cache = FileCache()
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
cached_read_file(tmp_path / "nope.txt", tmp_path, cache=cache)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tool-level cache-hit dedup tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config() -> AppConfig:
|
||||||
|
return load_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_workspace(tmp_path: Path, config: AppConfig) -> tuple[Path, AppConfig]:
|
||||||
|
config.agent.workspace_root = tmp_path
|
||||||
|
return tmp_path, config
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadFileToolCacheHit:
|
||||||
|
def test_first_read_returns_full_content(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
cache = FileCache()
|
||||||
|
(ws / "hello.txt").write_text("hello world")
|
||||||
|
|
||||||
|
tool = ReadFileTool(ws, cfg, file_cache=cache)
|
||||||
|
result = tool.run("tc-1", {"file_path": "hello.txt"})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert result.output == "hello world"
|
||||||
|
|
||||||
|
def test_second_read_returns_cached_message(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
cache = FileCache()
|
||||||
|
(ws / "hello.txt").write_text("hello world")
|
||||||
|
|
||||||
|
tool = ReadFileTool(ws, cfg, file_cache=cache)
|
||||||
|
tool.run("tc-1", {"file_path": "hello.txt"})
|
||||||
|
|
||||||
|
result2 = tool.run("tc-2", {"file_path": "hello.txt"})
|
||||||
|
assert result2.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "[Cached]" in result2.output
|
||||||
|
assert "hello.txt" in result2.output
|
||||||
|
assert "hello world" not in result2.output
|
||||||
|
|
||||||
|
def test_changed_file_returns_full_content_again(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
cache = FileCache()
|
||||||
|
f = ws / "data.txt"
|
||||||
|
f.write_text("v1")
|
||||||
|
|
||||||
|
tool = ReadFileTool(ws, cfg, file_cache=cache)
|
||||||
|
tool.run("tc-1", {"file_path": "data.txt"})
|
||||||
|
|
||||||
|
# Mutate file so mtime changes
|
||||||
|
time.sleep(0.05)
|
||||||
|
f.write_text("v2")
|
||||||
|
|
||||||
|
result2 = tool.run("tc-2", {"file_path": "data.txt"})
|
||||||
|
assert result2.status == ToolResultStatus.SUCCESS
|
||||||
|
assert result2.output == "v2"
|
||||||
|
assert "[Cached]" not in result2.output
|
||||||
|
|
||||||
|
def test_no_cache_always_returns_content(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "hello.txt").write_text("hello")
|
||||||
|
|
||||||
|
tool = ReadFileTool(ws, cfg, file_cache=None)
|
||||||
|
r1 = tool.run("tc-1", {"file_path": "hello.txt"})
|
||||||
|
r2 = tool.run("tc-2", {"file_path": "hello.txt"})
|
||||||
|
assert r1.output == "hello"
|
||||||
|
assert r2.output == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadManyFilesToolCacheHit:
|
||||||
|
def test_cached_files_get_short_message(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
cache = FileCache()
|
||||||
|
(ws / "a.txt").write_text("alpha")
|
||||||
|
(ws / "b.txt").write_text("bravo")
|
||||||
|
|
||||||
|
tool = ReadManyFilesTool(ws, cfg, file_cache=cache)
|
||||||
|
|
||||||
|
# First read — full content
|
||||||
|
r1 = tool.run("tc-1", {"file_paths": ["a.txt", "b.txt"]})
|
||||||
|
assert "alpha" in r1.output
|
||||||
|
assert "bravo" in r1.output
|
||||||
|
|
||||||
|
# Second read — cached messages
|
||||||
|
r2 = tool.run("tc-2", {"file_paths": ["a.txt", "b.txt"]})
|
||||||
|
assert "[Cached]" in r2.output
|
||||||
|
assert "alpha" not in r2.output
|
||||||
|
assert "bravo" not in r2.output
|
||||||
69
tests/unit/test_filesystem_read_many.py
Normal file
69
tests/unit/test_filesystem_read_many.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Tests for the read_many_files tool."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import AppConfig, load_config
|
||||||
|
from app.models.tool_call import ToolResultStatus
|
||||||
|
from app.tools.filesystem import ReadManyFilesTool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config() -> AppConfig:
|
||||||
|
return load_config()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_workspace(tmp_path: Path, config: AppConfig) -> tuple[Path, AppConfig]:
|
||||||
|
"""Create a temporary workspace for read_many_files tests."""
|
||||||
|
config.agent.workspace_root = tmp_path
|
||||||
|
return tmp_path, config
|
||||||
|
|
||||||
|
|
||||||
|
class TestReadManyFilesTool:
|
||||||
|
def test_read_multiple_files(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "a.txt").write_text("alpha")
|
||||||
|
(ws / "b.txt").write_text("bravo")
|
||||||
|
tool = ReadManyFilesTool(ws, cfg)
|
||||||
|
result = tool.run("tc-1", {"file_paths": ["a.txt", "b.txt"]})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "=== a.txt ===" in result.output
|
||||||
|
assert "alpha" in result.output
|
||||||
|
assert "=== b.txt ===" in result.output
|
||||||
|
assert "bravo" in result.output
|
||||||
|
|
||||||
|
def test_partial_failure(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "exists.txt").write_text("hello")
|
||||||
|
tool = ReadManyFilesTool(ws, cfg)
|
||||||
|
result = tool.run("tc-2", {"file_paths": ["exists.txt", "missing.txt"]})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "hello" in result.output
|
||||||
|
assert "[ERROR]" in result.output
|
||||||
|
assert "=== missing.txt ===" in result.output
|
||||||
|
|
||||||
|
def test_all_files_fail(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = ReadManyFilesTool(ws, cfg)
|
||||||
|
result = tool.run("tc-3", {"file_paths": ["no1.txt", "no2.txt"]})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "All files failed" in (result.error or "")
|
||||||
|
|
||||||
|
def test_empty_file_paths(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
tool = ReadManyFilesTool(ws, cfg)
|
||||||
|
result = tool.run("tc-4", {"file_paths": []})
|
||||||
|
assert result.status == ToolResultStatus.ERROR
|
||||||
|
assert "empty" in (result.error or "").lower()
|
||||||
|
|
||||||
|
def test_path_security_inline_error(self, tmp_workspace: tuple[Path, AppConfig]) -> None:
|
||||||
|
ws, cfg = tmp_workspace
|
||||||
|
(ws / "safe.txt").write_text("ok")
|
||||||
|
tool = ReadManyFilesTool(ws, cfg)
|
||||||
|
result = tool.run("tc-5", {"file_paths": ["safe.txt", "../../etc/passwd"]})
|
||||||
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
|
assert "ok" in result.output
|
||||||
|
assert "[ERROR]" in result.output
|
||||||
|
assert "outside" in result.output.lower()
|
||||||
31
tests/unit/test_logging_tui.py
Normal file
31
tests/unit/test_logging_tui.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Tests for TUI-safe logging configuration."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from app.utils.logging import setup_logging, setup_logging_for_tui
|
||||||
|
|
||||||
|
|
||||||
|
class TestTuiLogging:
|
||||||
|
def test_setup_logging_for_tui_removes_rich_handler(self) -> None:
|
||||||
|
"""TUI mode should have no RichHandler on root logger."""
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
|
# First set up normal logging (adds RichHandler)
|
||||||
|
setup_logging()
|
||||||
|
root = logging.getLogger()
|
||||||
|
assert any(isinstance(h, RichHandler) for h in root.handlers)
|
||||||
|
|
||||||
|
# Switch to TUI mode
|
||||||
|
setup_logging_for_tui()
|
||||||
|
root = logging.getLogger()
|
||||||
|
assert not any(isinstance(h, RichHandler) for h in root.handlers)
|
||||||
|
|
||||||
|
def test_setup_logging_for_tui_keeps_file_handler(self, tmp_path) -> None:
|
||||||
|
"""TUI mode should preserve file handler if configured."""
|
||||||
|
log_file = tmp_path / "test.log"
|
||||||
|
setup_logging(log_file=log_file)
|
||||||
|
setup_logging_for_tui()
|
||||||
|
|
||||||
|
root = logging.getLogger()
|
||||||
|
file_handlers = [h for h in root.handlers if isinstance(h, logging.FileHandler)]
|
||||||
|
assert len(file_handlers) == 1
|
||||||
51
tests/unit/test_permissions.py
Normal file
51
tests/unit/test_permissions.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
"""Tests for async PermissionsService."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import PermissionsConfig
|
||||||
|
from app.services.permissions import PermissionsService
|
||||||
|
|
||||||
|
|
||||||
|
class TestPermissionsService:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_approve(self) -> None:
|
||||||
|
config = PermissionsConfig(auto_approve=["read_file"])
|
||||||
|
svc = PermissionsService(config)
|
||||||
|
assert await svc.check("read_file") is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_deny(self) -> None:
|
||||||
|
config = PermissionsConfig(deny=["rm_file"])
|
||||||
|
svc = PermissionsService(config)
|
||||||
|
assert await svc.check("rm_file") is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_callback_approve(self) -> None:
|
||||||
|
config = PermissionsConfig()
|
||||||
|
svc = PermissionsService(config)
|
||||||
|
|
||||||
|
async def approve_callback(tool_name: str, description: str) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
svc.set_prompt_callback(approve_callback)
|
||||||
|
assert await svc.check("write_file", description="write something") is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_callback_deny(self) -> None:
|
||||||
|
config = PermissionsConfig()
|
||||||
|
svc = PermissionsService(config)
|
||||||
|
|
||||||
|
async def deny_callback(tool_name: str, description: str) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
svc.set_prompt_callback(deny_callback)
|
||||||
|
assert await svc.check("write_file") is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_callback_defaults_to_deny(self) -> None:
|
||||||
|
"""Without a callback set, unlisted tools are denied."""
|
||||||
|
config = PermissionsConfig()
|
||||||
|
svc = PermissionsService(config)
|
||||||
|
assert await svc.check("write_file") is False
|
||||||
125
tests/unit/test_retry.py
Normal file
125
tests/unit/test_retry.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Unit tests for LLM retry with exponential backoff."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import LLMConfig
|
||||||
|
from app.models.message import Message
|
||||||
|
from app.services.llm import LLMClient, LLMConnectionError, LLMResponseError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def llm_config() -> LLMConfig:
|
||||||
|
return LLMConfig(
|
||||||
|
model="test-model",
|
||||||
|
endpoint="http://localhost:11434",
|
||||||
|
max_retries=3,
|
||||||
|
retry_backoff_base=0.01,
|
||||||
|
retry_backoff_max=0.05,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(llm_config: LLMConfig) -> LLMClient:
|
||||||
|
return LLMClient(llm_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def messages() -> list[Message]:
|
||||||
|
return [Message(role="user", content="Hello")]
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetry:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_succeeds_without_retry(self, client: LLMClient, messages: list[Message]) -> None:
|
||||||
|
"""Successful stream doesn't retry."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def fake_stream(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
yield {"choices": [{"delta": {"content": "Hi"}}]}
|
||||||
|
|
||||||
|
client.stream_chat = fake_stream # type: ignore[assignment]
|
||||||
|
|
||||||
|
collected = []
|
||||||
|
async for chunk in client.stream_chat_with_retry(messages):
|
||||||
|
collected.append(chunk)
|
||||||
|
|
||||||
|
assert len(collected) == 1
|
||||||
|
assert call_count == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retries_on_connection_error(self, client: LLMClient, messages: list[Message]) -> None:
|
||||||
|
"""Retries on LLMConnectionError, then succeeds."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def flaky_stream(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 3:
|
||||||
|
raise LLMConnectionError("Connection refused")
|
||||||
|
yield {"choices": [{"delta": {"content": "OK"}}]}
|
||||||
|
|
||||||
|
client.stream_chat = flaky_stream # type: ignore[assignment]
|
||||||
|
|
||||||
|
with patch("app.services.llm.asyncio.sleep", new_callable=AsyncMock):
|
||||||
|
collected = []
|
||||||
|
async for chunk in client.stream_chat_with_retry(messages):
|
||||||
|
collected.append(chunk)
|
||||||
|
|
||||||
|
assert len(collected) == 1
|
||||||
|
assert call_count == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retries_on_5xx(self, client: LLMClient, messages: list[Message]) -> None:
|
||||||
|
"""Retries on 5xx LLMResponseError."""
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def server_error_stream(*args, **kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count < 2:
|
||||||
|
raise LLMResponseError("Internal Server Error", status_code=500)
|
||||||
|
yield {"choices": [{"delta": {"content": "OK"}}]}
|
||||||
|
|
||||||
|
client.stream_chat = server_error_stream # type: ignore[assignment]
|
||||||
|
|
||||||
|
with patch("app.services.llm.asyncio.sleep", new_callable=AsyncMock):
|
||||||
|
collected = []
|
||||||
|
async for chunk in client.stream_chat_with_retry(messages):
|
||||||
|
collected.append(chunk)
|
||||||
|
|
||||||
|
assert len(collected) == 1
|
||||||
|
assert call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_retry_on_4xx(self, client: LLMClient, messages: list[Message]) -> None:
|
||||||
|
"""Does NOT retry on 4xx errors — raises immediately."""
|
||||||
|
async def bad_request_stream(*args, **kwargs):
|
||||||
|
raise LLMResponseError("Bad Request", status_code=400)
|
||||||
|
yield # pragma: no cover — make this an async generator
|
||||||
|
|
||||||
|
client.stream_chat = bad_request_stream # type: ignore[assignment]
|
||||||
|
|
||||||
|
with pytest.raises(LLMResponseError, match="Bad Request"):
|
||||||
|
async for _ in client.stream_chat_with_retry(messages):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_respects_max_retries(self, client: LLMClient, messages: list[Message]) -> None:
|
||||||
|
"""After exhausting retries, re-raises the last exception."""
|
||||||
|
async def always_fail(*args, **kwargs):
|
||||||
|
raise LLMConnectionError("Down forever")
|
||||||
|
yield # pragma: no cover
|
||||||
|
|
||||||
|
client.stream_chat = always_fail # type: ignore[assignment]
|
||||||
|
|
||||||
|
with patch("app.services.llm.asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
|
||||||
|
with pytest.raises(LLMConnectionError, match="Down forever"):
|
||||||
|
async for _ in client.stream_chat_with_retry(messages):
|
||||||
|
pass # pragma: no cover
|
||||||
|
|
||||||
|
# Should have slept max_retries times (3 retries after initial attempt)
|
||||||
|
assert mock_sleep.call_count == 3
|
||||||
122
tests/unit/test_session.py
Normal file
122
tests/unit/test_session.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Unit tests for session persistence."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agent.context import SessionContext
|
||||||
|
from app.models.config import AgentConfig, AppConfig, LLMConfig, SessionConfig
|
||||||
|
from app.services.session import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_workspace(tmp_path: Path) -> Path:
|
||||||
|
return tmp_path / "workspace"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session_config(tmp_workspace: Path) -> SessionConfig:
|
||||||
|
return SessionConfig(
|
||||||
|
session_dir=tmp_workspace / ".sneakycode" / "sessions",
|
||||||
|
auto_save=True,
|
||||||
|
max_session_age_hours=72,
|
||||||
|
offer_resume=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config(tmp_workspace: Path) -> AppConfig:
|
||||||
|
return AppConfig(
|
||||||
|
llm=LLMConfig(model="test-model", endpoint="http://localhost:11434"),
|
||||||
|
agent=AgentConfig(workspace_root=tmp_workspace),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ctx(config: AppConfig) -> SessionContext:
|
||||||
|
return SessionContext(config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session_mgr(session_config: SessionConfig, tmp_workspace: Path) -> SessionManager:
|
||||||
|
tmp_workspace.mkdir(parents=True, exist_ok=True)
|
||||||
|
return SessionManager(session_config, tmp_workspace, "test-model")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionPersistence:
|
||||||
|
def test_save_creates_file(self, session_mgr: SessionManager, ctx: SessionContext, session_config: SessionConfig) -> None:
|
||||||
|
"""Saving a session creates a JSON file in the session directory."""
|
||||||
|
ctx.add_message("user", "Hello")
|
||||||
|
ctx.add_message("assistant", "Hi there!")
|
||||||
|
|
||||||
|
path = session_mgr.save(ctx)
|
||||||
|
|
||||||
|
assert path.exists()
|
||||||
|
assert path.suffix == ".json"
|
||||||
|
|
||||||
|
data = json.loads(path.read_text())
|
||||||
|
assert data["model"] == "test-model"
|
||||||
|
assert len(data["messages"]) == 2
|
||||||
|
|
||||||
|
def test_load_latest_returns_newest(self, session_mgr: SessionManager, ctx: SessionContext, session_config: SessionConfig) -> None:
|
||||||
|
"""load_latest returns the most recently modified session."""
|
||||||
|
ctx.add_message("user", "First session")
|
||||||
|
session_mgr.save(ctx)
|
||||||
|
|
||||||
|
# Create a second session manager (simulates a new startup)
|
||||||
|
mgr2 = SessionManager(session_config, session_mgr._workspace_root, "test-model")
|
||||||
|
ctx.add_message("assistant", "Second response")
|
||||||
|
path2 = mgr2.save(ctx)
|
||||||
|
|
||||||
|
loaded = mgr2.load_latest()
|
||||||
|
assert loaded is not None
|
||||||
|
assert loaded.session_id == mgr2._session_id
|
||||||
|
assert len(loaded.messages) == 2
|
||||||
|
|
||||||
|
def test_restore_populates_context(self, session_mgr: SessionManager, ctx: SessionContext, config: AppConfig) -> None:
|
||||||
|
"""Restoring a session populates the context with saved messages."""
|
||||||
|
ctx.add_message("user", "Hello")
|
||||||
|
ctx.add_message("assistant", "World")
|
||||||
|
session_mgr.save(ctx)
|
||||||
|
|
||||||
|
# Load and restore into a fresh context
|
||||||
|
fresh_ctx = SessionContext(config)
|
||||||
|
loaded = session_mgr.load_latest()
|
||||||
|
assert loaded is not None
|
||||||
|
|
||||||
|
session_mgr.restore(loaded, fresh_ctx)
|
||||||
|
|
||||||
|
history = fresh_ctx.get_history()
|
||||||
|
assert len(history) == 2
|
||||||
|
assert history[0].role == "user"
|
||||||
|
assert history[0].content == "Hello"
|
||||||
|
assert history[1].role == "assistant"
|
||||||
|
assert history[1].content == "World"
|
||||||
|
|
||||||
|
def test_cleanup_removes_old_files(self, session_config: SessionConfig, tmp_workspace: Path) -> None:
|
||||||
|
"""cleanup_old deletes files older than max_session_age_hours."""
|
||||||
|
tmp_workspace.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create session config with very short max age
|
||||||
|
short_config = SessionConfig(
|
||||||
|
session_dir=session_config.session_dir,
|
||||||
|
max_session_age_hours=0, # 0 hours = everything is old
|
||||||
|
)
|
||||||
|
mgr = SessionManager(short_config, tmp_workspace, "test-model")
|
||||||
|
|
||||||
|
# Create a session file manually with old timestamp
|
||||||
|
session_dir = tmp_workspace / short_config.session_dir
|
||||||
|
session_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
old_file = session_dir / "old_session.json"
|
||||||
|
old_file.write_text('{"version": 1}')
|
||||||
|
|
||||||
|
# Set mtime to the past
|
||||||
|
import os
|
||||||
|
old_time = time.time() - 3600 # 1 hour ago
|
||||||
|
os.utime(old_file, (old_time, old_time))
|
||||||
|
|
||||||
|
deleted = mgr.cleanup_old()
|
||||||
|
assert deleted == 1
|
||||||
|
assert not old_file.exists()
|
||||||
@@ -90,6 +90,6 @@ class TestRunCommandTool:
|
|||||||
# Create a file in the workspace to verify cwd
|
# Create a file in the workspace to verify cwd
|
||||||
(ws / "marker.txt").write_text("found")
|
(ws / "marker.txt").write_text("found")
|
||||||
tool = RunCommandTool(ws, cfg)
|
tool = RunCommandTool(ws, cfg)
|
||||||
result = tool.run("tc-9", {"command": "cat marker.txt"})
|
result = tool.run("tc-9", {"command": "head marker.txt"})
|
||||||
assert result.status == ToolResultStatus.SUCCESS
|
assert result.status == ToolResultStatus.SUCCESS
|
||||||
assert "found" in result.output
|
assert "found" in result.output
|
||||||
|
|||||||
111
tests/unit/test_streaming.py
Normal file
111
tests/unit/test_streaming.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Tests for callback-based StreamHandler."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, call
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.models.config import DisplayConfig
|
||||||
|
from app.services.streaming import StreamHandler
|
||||||
|
|
||||||
|
|
||||||
|
def _make_chunk(content: str | None = None, reasoning: str | None = None) -> dict:
|
||||||
|
"""Helper to create a fake SSE chunk."""
|
||||||
|
delta: dict = {}
|
||||||
|
if content is not None:
|
||||||
|
delta["content"] = content
|
||||||
|
if reasoning is not None:
|
||||||
|
delta["reasoning"] = reasoning
|
||||||
|
return {"choices": [{"delta": delta}]}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool_call_chunk(index: int, tc_id: str = "", name: str = "", args: str = "") -> dict:
|
||||||
|
"""Helper to create a fake tool call chunk."""
|
||||||
|
tc_delta: dict = {"index": index}
|
||||||
|
if tc_id:
|
||||||
|
tc_delta["id"] = tc_id
|
||||||
|
func: dict = {}
|
||||||
|
if name:
|
||||||
|
func["name"] = name
|
||||||
|
if args:
|
||||||
|
func["arguments"] = args
|
||||||
|
if func:
|
||||||
|
tc_delta["function"] = func
|
||||||
|
return {"choices": [{"delta": {"tool_calls": [tc_delta]}}]}
|
||||||
|
|
||||||
|
|
||||||
|
async def _async_iter(items: list[dict]):
|
||||||
|
for item in items:
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
class TestStreamHandlerCallbacks:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_content_called_with_accumulated_text(self) -> None:
|
||||||
|
handler = StreamHandler(DisplayConfig(stream_output=True))
|
||||||
|
on_content = MagicMock()
|
||||||
|
on_thinking = MagicMock()
|
||||||
|
on_done = MagicMock()
|
||||||
|
handler.set_callbacks(on_content=on_content, on_thinking=on_thinking, on_done=on_done)
|
||||||
|
|
||||||
|
chunks = [_make_chunk(content="Hello"), _make_chunk(content=" world")]
|
||||||
|
msg = await handler.process_stream(_async_iter(chunks))
|
||||||
|
|
||||||
|
assert msg.content == "Hello world"
|
||||||
|
assert on_content.call_count >= 1
|
||||||
|
# Last call should have full accumulated content
|
||||||
|
last_content = on_content.call_args_list[-1][0][0]
|
||||||
|
assert last_content == "Hello world"
|
||||||
|
on_done.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_thinking_called_for_reasoning(self) -> None:
|
||||||
|
handler = StreamHandler(DisplayConfig(stream_output=True))
|
||||||
|
on_content = MagicMock()
|
||||||
|
on_thinking = MagicMock()
|
||||||
|
on_done = MagicMock()
|
||||||
|
handler.set_callbacks(on_content=on_content, on_thinking=on_thinking, on_done=on_done)
|
||||||
|
|
||||||
|
chunks = [_make_chunk(reasoning="let me think")]
|
||||||
|
msg = await handler.process_stream(_async_iter(chunks))
|
||||||
|
|
||||||
|
on_thinking.assert_called_once()
|
||||||
|
on_content.assert_not_called()
|
||||||
|
on_done.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_display_callbacks_skip_when_stream_output_disabled(self) -> None:
|
||||||
|
"""on_content and on_thinking are suppressed, but on_done always fires."""
|
||||||
|
handler = StreamHandler(DisplayConfig(stream_output=False))
|
||||||
|
on_content = MagicMock()
|
||||||
|
on_thinking = MagicMock()
|
||||||
|
on_done = MagicMock()
|
||||||
|
handler.set_callbacks(on_content=on_content, on_thinking=on_thinking, on_done=on_done)
|
||||||
|
|
||||||
|
chunks = [_make_chunk(content="Hello")]
|
||||||
|
msg = await handler.process_stream(_async_iter(chunks))
|
||||||
|
|
||||||
|
assert msg.content == "Hello"
|
||||||
|
on_content.assert_not_called()
|
||||||
|
on_thinking.assert_not_called()
|
||||||
|
on_done.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_callbacks_by_default(self) -> None:
|
||||||
|
"""process_stream works without set_callbacks (backward compat)."""
|
||||||
|
handler = StreamHandler(DisplayConfig(stream_output=True))
|
||||||
|
chunks = [_make_chunk(content="Hello")]
|
||||||
|
msg = await handler.process_stream(_async_iter(chunks))
|
||||||
|
assert msg.content == "Hello"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_calls_still_accumulated(self) -> None:
|
||||||
|
handler = StreamHandler(DisplayConfig(stream_output=True))
|
||||||
|
chunks = [
|
||||||
|
_make_tool_call_chunk(0, tc_id="call_1", name="read_file"),
|
||||||
|
_make_tool_call_chunk(0, args='{"path": "foo.py"}'),
|
||||||
|
]
|
||||||
|
msg = await handler.process_stream(_async_iter(chunks))
|
||||||
|
assert msg.tool_calls is not None
|
||||||
|
assert len(msg.tool_calls) == 1
|
||||||
|
assert msg.tool_calls[0].function.name == "read_file"
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Tests for the tool framework and core tools (Phase 4)."""
|
"""Tests for the tool framework and core tools (Phase 4)."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -48,28 +47,40 @@ class TestBaseTool:
|
|||||||
|
|
||||||
|
|
||||||
class TestPermissionsService:
|
class TestPermissionsService:
|
||||||
def test_deny_list_blocks(self) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_deny_list_blocks(self) -> None:
|
||||||
svc = PermissionsService(PermissionsConfig(deny=["dangerous_tool"]))
|
svc = PermissionsService(PermissionsConfig(deny=["dangerous_tool"]))
|
||||||
assert svc.check("dangerous_tool") is False
|
assert await svc.check("dangerous_tool") is False
|
||||||
|
|
||||||
def test_auto_approve_allows(self) -> None:
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_approve_allows(self) -> None:
|
||||||
svc = PermissionsService(PermissionsConfig(auto_approve=["read_file"]))
|
svc = PermissionsService(PermissionsConfig(auto_approve=["read_file"]))
|
||||||
assert svc.check("read_file") is True
|
assert await svc.check("read_file") is True
|
||||||
|
|
||||||
@patch("app.services.permissions.Confirm.ask", return_value=True)
|
@pytest.mark.asyncio
|
||||||
def test_prompt_user_approved(self, mock_ask: object) -> None:
|
async def test_prompt_callback_approved(self) -> None:
|
||||||
svc = PermissionsService(PermissionsConfig(prompt_user=["write_file"]))
|
svc = PermissionsService(PermissionsConfig(prompt_user=["write_file"]))
|
||||||
assert svc.check("write_file") is True
|
|
||||||
|
|
||||||
@patch("app.services.permissions.Confirm.ask", return_value=False)
|
async def approve(tool_name: str, description: str) -> bool:
|
||||||
def test_prompt_user_denied(self, mock_ask: object) -> None:
|
return True
|
||||||
|
|
||||||
|
svc.set_prompt_callback(approve)
|
||||||
|
assert await svc.check("write_file") is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_callback_denied(self) -> None:
|
||||||
svc = PermissionsService(PermissionsConfig(prompt_user=["write_file"]))
|
svc = PermissionsService(PermissionsConfig(prompt_user=["write_file"]))
|
||||||
assert svc.check("write_file") is False
|
|
||||||
|
|
||||||
@patch("app.services.permissions.Confirm.ask", return_value=False)
|
async def deny(tool_name: str, description: str) -> bool:
|
||||||
def test_unlisted_tool_prompts(self, mock_ask: object) -> None:
|
return False
|
||||||
|
|
||||||
|
svc.set_prompt_callback(deny)
|
||||||
|
assert await svc.check("write_file") is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unlisted_tool_no_callback_denied(self) -> None:
|
||||||
svc = PermissionsService(PermissionsConfig())
|
svc = PermissionsService(PermissionsConfig())
|
||||||
assert svc.check("unknown_tool") is False
|
assert await svc.check("unknown_tool") is False
|
||||||
|
|
||||||
|
|
||||||
# --- ToolRegistry ---
|
# --- ToolRegistry ---
|
||||||
@@ -97,7 +108,7 @@ class TestToolRegistry:
|
|||||||
registry = create_default_registry(workspace, config)
|
registry = create_default_registry(workspace, config)
|
||||||
names = set(registry.get_all().keys())
|
names = set(registry.get_all().keys())
|
||||||
assert names == {
|
assert names == {
|
||||||
"read_file", "list_dir", "grep_files", "find_files",
|
"read_file", "read_many_files", "list_dir", "grep_files", "find_files",
|
||||||
"write_file", "make_dir", "delete_file",
|
"write_file", "make_dir", "delete_file",
|
||||||
"str_replace", "patch_apply",
|
"str_replace", "patch_apply",
|
||||||
"run_command",
|
"run_command",
|
||||||
@@ -107,7 +118,7 @@ class TestToolRegistry:
|
|||||||
def test_schema_export(self, workspace: Path, config: AppConfig) -> None:
|
def test_schema_export(self, workspace: Path, config: AppConfig) -> None:
|
||||||
registry = create_default_registry(workspace, config)
|
registry = create_default_registry(workspace, config)
|
||||||
schemas = registry.get_openai_tools_schema()
|
schemas = registry.get_openai_tools_schema()
|
||||||
assert len(schemas) == 11
|
assert len(schemas) == 12
|
||||||
assert all(s["type"] == "function" for s in schemas)
|
assert all(s["type"] == "function" for s in schemas)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
128
tests/unit/test_truncation.py
Normal file
128
tests/unit/test_truncation.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Unit tests for conversation truncation logic."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.agent.context import SessionContext
|
||||||
|
from app.models.config import AgentConfig, AppConfig, LLMConfig
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config() -> AppConfig:
|
||||||
|
return AppConfig(
|
||||||
|
llm=LLMConfig(model="test-model", endpoint="http://localhost:11434"),
|
||||||
|
agent=AgentConfig(
|
||||||
|
max_conversation_tokens=200,
|
||||||
|
truncation_keep_recent=3,
|
||||||
|
truncation_threshold=0.85,
|
||||||
|
workspace_root=Path("/tmp/test"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def ctx(config: AppConfig) -> SessionContext:
|
||||||
|
return SessionContext(config)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTruncation:
|
||||||
|
def test_no_truncation_under_threshold(self, ctx: SessionContext) -> None:
|
||||||
|
"""No messages dropped when under threshold."""
|
||||||
|
ctx.add_message("user", "Hello")
|
||||||
|
ctx.add_message("assistant", "Hi there!")
|
||||||
|
|
||||||
|
dropped = ctx.truncate_history()
|
||||||
|
assert dropped == 0
|
||||||
|
assert ctx.message_count == 2
|
||||||
|
|
||||||
|
def test_drops_oldest_messages(self, ctx: SessionContext) -> None:
|
||||||
|
"""Drops middle messages when over budget."""
|
||||||
|
# Fill with enough content to exceed the small 200-token budget
|
||||||
|
ctx.add_message("user", "First message " * 20)
|
||||||
|
for i in range(8):
|
||||||
|
ctx.add_message("assistant", f"Response {i} " * 15)
|
||||||
|
ctx.add_message("user", f"Follow-up {i} " * 15)
|
||||||
|
|
||||||
|
# Force the token counter to report over budget
|
||||||
|
from app.utils.token_counter import TokenUsage
|
||||||
|
ctx.token_counter.count_usage(TokenUsage(total_tokens=200))
|
||||||
|
|
||||||
|
original_count = len(ctx.get_history())
|
||||||
|
dropped = ctx.truncate_history()
|
||||||
|
|
||||||
|
assert dropped > 0
|
||||||
|
assert len(ctx.get_history()) < original_count
|
||||||
|
|
||||||
|
def test_preserves_recent_messages(self, ctx: SessionContext) -> None:
|
||||||
|
"""The most recent N messages are always preserved."""
|
||||||
|
ctx.add_message("user", "First message " * 20)
|
||||||
|
for i in range(10):
|
||||||
|
ctx.add_message("assistant", f"Response {i} " * 10)
|
||||||
|
ctx.add_message("user", f"Follow-up {i} " * 10)
|
||||||
|
|
||||||
|
from app.utils.token_counter import TokenUsage
|
||||||
|
ctx.token_counter.count_usage(TokenUsage(total_tokens=200))
|
||||||
|
|
||||||
|
history_before = ctx.get_history()
|
||||||
|
recent_before = history_before[-3:] # keep_recent=3
|
||||||
|
|
||||||
|
ctx.truncate_history()
|
||||||
|
|
||||||
|
history_after = ctx.get_history()
|
||||||
|
recent_after = history_after[-3:]
|
||||||
|
|
||||||
|
# Recent messages should be preserved
|
||||||
|
for before, after in zip(recent_before, recent_after):
|
||||||
|
assert before.content == after.content
|
||||||
|
|
||||||
|
def test_preserves_first_user_message(self, ctx: SessionContext) -> None:
|
||||||
|
"""First user message is always kept."""
|
||||||
|
first_content = "This is the very first user message"
|
||||||
|
ctx.add_message("user", first_content)
|
||||||
|
for i in range(10):
|
||||||
|
ctx.add_message("assistant", f"Response {i} " * 10)
|
||||||
|
ctx.add_message("user", f"Follow-up {i} " * 10)
|
||||||
|
|
||||||
|
from app.utils.token_counter import TokenUsage
|
||||||
|
ctx.token_counter.count_usage(TokenUsage(total_tokens=200))
|
||||||
|
|
||||||
|
ctx.truncate_history()
|
||||||
|
|
||||||
|
history = ctx.get_history()
|
||||||
|
assert history[0].role == "user"
|
||||||
|
assert history[0].content == first_content
|
||||||
|
|
||||||
|
def test_orphaned_tool_messages_cleaned(self, ctx: SessionContext) -> None:
|
||||||
|
"""Tool messages without matching tool_call are cleaned up."""
|
||||||
|
from app.models.tool_call import ToolCall, ToolCallFunction
|
||||||
|
|
||||||
|
ctx.add_message("user", "Do something " * 20)
|
||||||
|
# Assistant with tool call
|
||||||
|
ctx.add_message(
|
||||||
|
"assistant",
|
||||||
|
None,
|
||||||
|
tool_calls=[ToolCall(id="tc_1", type="function", function=ToolCallFunction(name="read_file", arguments='{"path": "x"}'))],
|
||||||
|
)
|
||||||
|
# Tool result for tc_1
|
||||||
|
ctx.add_message("tool", "file contents " * 20, tool_call_id="tc_1", name="read_file")
|
||||||
|
# More padding to push over budget
|
||||||
|
for i in range(8):
|
||||||
|
ctx.add_message("assistant", f"Analysis {i} " * 15)
|
||||||
|
ctx.add_message("user", f"Next {i} " * 15)
|
||||||
|
|
||||||
|
from app.utils.token_counter import TokenUsage
|
||||||
|
ctx.token_counter.count_usage(TokenUsage(total_tokens=200))
|
||||||
|
|
||||||
|
ctx.truncate_history()
|
||||||
|
|
||||||
|
history = ctx.get_history()
|
||||||
|
# If the assistant message with tc_1 was dropped, the orphaned tool message should also be gone
|
||||||
|
has_tc1_assistant = any(
|
||||||
|
m.role == "assistant" and m.tool_calls and any(tc.id == "tc_1" for tc in m.tool_calls)
|
||||||
|
for m in history
|
||||||
|
)
|
||||||
|
has_tc1_tool = any(m.role == "tool" and m.tool_call_id == "tc_1" for m in history)
|
||||||
|
|
||||||
|
# Either both exist or neither exists
|
||||||
|
assert has_tc1_assistant == has_tc1_tool
|
||||||
66
uv.lock
generated
66
uv.lock
generated
@@ -97,6 +97,18 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "linkify-it-py"
|
||||||
|
version = "2.1.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "uc-micro-py" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/2e/c9/06ea13676ef354f0af6169587ae292d3e2406e212876a413bf9eece4eb23/linkify_it_py-2.1.0.tar.gz", hash = "sha256:43360231720999c10e9328dc3691160e27a718e280673d444c38d7d3aaa3b98b", size = 29158, upload-time = "2026-03-01T07:48:47.683Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/b4/de/88b3be5c31b22333b3ca2f6ff1de4e863d8fe45aaea7485f591970ec1d3e/linkify_it_py-2.1.0-py3-none-any.whl", hash = "sha256:0d252c1594ecba2ecedc444053db5d3a9b7ec1b0dd929c8f1d74dce89f86c05e", size = 19878, upload-time = "2026-03-01T07:48:46.098Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "markdown-it-py"
|
name = "markdown-it-py"
|
||||||
version = "4.0.0"
|
version = "4.0.0"
|
||||||
@@ -109,6 +121,23 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
|
{ url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[package.optional-dependencies]
|
||||||
|
linkify = [
|
||||||
|
{ name = "linkify-it-py" },
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mdit-py-plugins"
|
||||||
|
version = "0.5.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "markdown-it-py" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/b2/fd/a756d36c0bfba5f6e39a1cdbdbfdd448dc02692467d83816dff4592a1ebc/mdit_py_plugins-0.5.0.tar.gz", hash = "sha256:f4918cb50119f50446560513a8e311d574ff6aaed72606ddae6d35716fe809c6", size = 44655, upload-time = "2025-08-11T07:25:49.083Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/fb/86/dd6e5db36df29e76c7a7699123569a4a18c1623ce68d826ed96c62643cae/mdit_py_plugins-0.5.0-py3-none-any.whl", hash = "sha256:07a08422fc1936a5d26d146759e9155ea466e842f5ab2f7d2266dd084c8dab1f", size = 57205, upload-time = "2025-08-11T07:25:47.597Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mdurl"
|
name = "mdurl"
|
||||||
version = "0.1.2"
|
version = "0.1.2"
|
||||||
@@ -127,6 +156,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
|
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "platformdirs"
|
||||||
|
version = "4.9.4"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/19/56/8d4c30c8a1d07013911a8fdbd8f89440ef9f08d07a1b50ab8ca8be5a20f9/platformdirs-4.9.4.tar.gz", hash = "sha256:1ec356301b7dc906d83f371c8f487070e99d3ccf9e501686456394622a01a934", size = 28737, upload-time = "2026-03-05T18:34:13.271Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl", hash = "sha256:68a9a4619a666ea6439f2ff250c12a853cd1cbd5158d258bd824a7df6be2f868", size = 21216, upload-time = "2026-03-05T18:34:12.172Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pluggy"
|
name = "pluggy"
|
||||||
version = "1.6.0"
|
version = "1.6.0"
|
||||||
@@ -389,6 +427,7 @@ dependencies = [
|
|||||||
{ name = "pyyaml" },
|
{ name = "pyyaml" },
|
||||||
{ name = "rich" },
|
{ name = "rich" },
|
||||||
{ name = "structlog" },
|
{ name = "structlog" },
|
||||||
|
{ name = "textual" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
@@ -408,6 +447,7 @@ requires-dist = [
|
|||||||
{ name = "rich", specifier = ">=13.0" },
|
{ name = "rich", specifier = ">=13.0" },
|
||||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.3" },
|
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.3" },
|
||||||
{ name = "structlog", specifier = ">=24.0" },
|
{ name = "structlog", specifier = ">=24.0" },
|
||||||
|
{ name = "textual", specifier = ">=4.0.0" },
|
||||||
]
|
]
|
||||||
provides-extras = ["dev"]
|
provides-extras = ["dev"]
|
||||||
|
|
||||||
@@ -420,6 +460,23 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/a8/45/a132b9074aa18e799b891b91ad72133c98d8042c70f6240e4c5f9dabee2f/structlog-25.5.0-py3-none-any.whl", hash = "sha256:a8453e9b9e636ec59bd9e79bbd4a72f025981b3ba0f5837aebf48f02f37a7f9f", size = 72510, upload-time = "2025-10-27T08:28:21.535Z" },
|
{ url = "https://files.pythonhosted.org/packages/a8/45/a132b9074aa18e799b891b91ad72133c98d8042c70f6240e4c5f9dabee2f/structlog-25.5.0-py3-none-any.whl", hash = "sha256:a8453e9b9e636ec59bd9e79bbd4a72f025981b3ba0f5837aebf48f02f37a7f9f", size = 72510, upload-time = "2025-10-27T08:28:21.535Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "textual"
|
||||||
|
version = "8.1.1"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
dependencies = [
|
||||||
|
{ name = "markdown-it-py", extra = ["linkify"] },
|
||||||
|
{ name = "mdit-py-plugins" },
|
||||||
|
{ name = "platformdirs" },
|
||||||
|
{ name = "pygments" },
|
||||||
|
{ name = "rich" },
|
||||||
|
{ name = "typing-extensions" },
|
||||||
|
]
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/72/23/8c709655c5f2208ee82ab81b8104802421865535c278a7649b842b129db1/textual-8.1.1.tar.gz", hash = "sha256:eef0256a6131f06a20ad7576412138c1f30f92ddeedd055953c08d97044bc317", size = 1843002, upload-time = "2026-03-10T10:01:38.493Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/50/21/421b02bf5943172b7a9320712a5e0d74a02a8f7597284e3f8b5b06c70b8d/textual-8.1.1-py3-none-any.whl", hash = "sha256:6712f96e335cd782e76193dee16b9c8875fe0699d923bc8d3f1228fd23e773a6", size = 719598, upload-time = "2026-03-10T10:01:48.318Z" },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typing-extensions"
|
name = "typing-extensions"
|
||||||
version = "4.15.0"
|
version = "4.15.0"
|
||||||
@@ -440,3 +497,12 @@ sdist = { url = "https://files.pythonhosted.org/packages/55/e3/70399cb7dd41c10ac
|
|||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" },
|
{ url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "uc-micro-py"
|
||||||
|
version = "2.0.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/78/67/9a363818028526e2d4579334460df777115bdec1bb77c08f9db88f6389f2/uc_micro_py-2.0.0.tar.gz", hash = "sha256:c53691e495c8db60e16ffc4861a35469b0ba0821fe409a8a7a0a71864d33a811", size = 6611, upload-time = "2026-03-01T06:31:27.526Z" }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/61/73/d21edf5b204d1467e06500080a50f79d49ef2b997c79123a536d4a17d97c/uc_micro_py-2.0.0-py3-none-any.whl", hash = "sha256:3603a3859af53e5a39bc7677713c78ea6589ff188d70f4fee165db88e22b242c", size = 6383, upload-time = "2026-03-01T06:31:26.257Z" },
|
||||||
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user