
#!/usr/bin/env python3
"""
RAG Learning System
A dual-mode RAG system designed for progressive learning with AI guidance.
Tracks your knowledge, suggests new topics, and helps identify learning gaps.
"""
import os
import sys
import json
import hashlib
import asyncio
import re
import yaml
from pathlib import Path
from collections import deque, defaultdict
from typing import List, Dict
from datetime import datetime
from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from rich.table import Table
from rich.prompt import Prompt, Confirm
from rich.progress import Progress, SpinnerColumn, TextColumn
from prompt_toolkit import PromptSession
from prompt_toolkit.styles import Style
from langchain_community.vectorstores.utils import filter_complex_metadata
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
# =========================
# CONFIGURATION
# =========================
console = Console(color_system="standard", force_terminal=True)
session = PromptSession()
load_dotenv()
style = Style.from_dict({"prompt": "bold #6a0dad"})
# Core Configuration
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
ANSWER_COLOR = os.getenv("ANSWER_COLOR", "blue")
# Enhanced System Prompts
SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT",
"You are a precise technical assistant. Use the provided context to answer questions accurately. "
"Cite sources using [filename]. If the context doesn't contain the answer, say so.")
SYSTEM_PROMPT_ANALYSIS = (
"You are an expert learning analytics tutor. Your task is to analyze a student's knowledge base "
"and provide insights about their learning progress.\n\n"
"When analyzing, consider:\n"
"1. What topics/subjects are covered in the notes\n"
"2. The depth and complexity of understanding demonstrated\n"
"3. Connections between different concepts\n"
"4. Gaps or missing fundamental concepts\n"
"5. Progression from beginner to advanced topics\n\n"
"Provide specific, actionable feedback about:\n"
"- What the student has learned well\n"
"- Areas that need more attention\n"
"- Recommended next topics to study\n"
"- How new topics connect to existing knowledge\n\n"
"Be encouraging but honest. Format your response clearly with sections."
)
SYSTEM_PROMPT_SUGGESTION = (
"You are a learning path advisor. Based on a student's current knowledge (shown in their notes), "
"suggest the next logical topics or skills to learn.\n\n"
"Your suggestions should:\n"
"1. Build upon existing knowledge\n"
"2. Fill identified gaps in understanding\n"
"3. Progress naturally from basics to advanced\n"
"4. Be specific and actionable\n\n"
"Format your response with:\n"
"- Recommended topics (with brief explanations)\n"
"- Prerequisites needed\n"
"- Why each topic is important\n"
"- Estimated difficulty level\n"
"- How it connects to what they already know"
)
USER_PROMPT_TEMPLATE = os.getenv("USER_PROMPT_TEMPLATE",
"Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
# Paths and Models
MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "mxbai-embed-large:latest")
LLM_MODEL = os.getenv("LLM_MODEL", "qwen2.5:7b-instruct-q8_0")
CHROMA_PATH = "./.cache/chroma_db"
HASH_CACHE = "./.cache/file_hashes.json"
PROGRESS_CACHE = "./.cache/learning_progress.json"
# Processing Configuration
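# Note: chunks from the coarse splitter (CHUNK_SIZE) are re-split so that no
# piece exceeds MAX_EMBED_CHARS -- assumed here to be a safe per-chunk size
# for the embedding model's context window.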
MAX_EMBED_CHARS = 380
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 200
TOP_K = 6
COLLECTION_NAME = "md_rag"
MAX_ANALYSIS_CONTEXT_CHARS = 24000
BATCH_SIZE = 10
MAX_PARALLEL_FILES = 3
# Learning Configuration
MAX_SUGGESTIONS = 5
PROGRESS_SUMMARY_DAYS = 7
# =========================
# UTILITY FUNCTIONS
# =========================
def get_file_hash(file_path: str) -> str:
"""Generate MD5 hash for file change detection"""
return hashlib.md5(Path(file_path).read_bytes()).hexdigest()
def load_json_cache(file_path: str) -> dict:
"""Load JSON cache with error handling"""
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
if Path(file_path).exists():
try:
return json.loads(Path(file_path).read_text())
except json.JSONDecodeError:
console.print(f"[yellow]⚠️ Corrupted cache: {file_path}. Resetting.[/yellow]")
return {}
return {}
def save_json_cache(cache: dict, file_path: str):
"""Save JSON cache with error handling"""
try:
Path(file_path).write_text(json.dumps(cache, indent=2))
except Exception as e:
console.print(f"[red]✗ Failed to save cache {file_path}: {e}[/red]")
def load_hash_cache() -> dict:
"""Load file hash cache"""
return load_json_cache(HASH_CACHE)
def save_hash_cache(cache: dict):
"""Save file hash cache"""
save_json_cache(cache, HASH_CACHE)
def load_progress_cache() -> dict:
"""Load learning progress cache"""
return load_json_cache(PROGRESS_CACHE)
def save_progress_cache(cache: dict):
"""Save learning progress cache"""
save_json_cache(cache, PROGRESS_CACHE)
def format_file_size(size_bytes: int) -> str:
"""Format file size for human reading"""
if size_bytes < 1024:
return f"{size_bytes} B"
elif size_bytes < 1024 * 1024:
return f"{size_bytes / 1024:.1f} KB"
else:
return f"{size_bytes / (1024 * 1024):.1f} MB"
# =========================
# INTENT CLASSIFICATION
# =========================
def classify_intent(query: str) -> str:
"""
Classify user intent into different modes:
- SEARCH: Standard RAG retrieval
- ANALYSIS: Progress and knowledge analysis
- SUGGEST: Topic and learning suggestions
- LEARN: Interactive learning mode
- STATS: Progress statistics
"""
query_lower = query.lower().strip()
# Analysis keywords (progress evaluation)
analysis_keywords = [
r"assess my progress", r"eval(uate)? my (learning|knowledge)",
r"what have i learned", r"summary of (my )?notes",
r"my progress", r"learning path", r"knowledge gap", r"analyze my",
r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний",
r"сегодня(?:\s+\w+)*\s*урок", r"что я изучил"
]
# Suggestion keywords
suggestion_keywords = [
r"what should i learn next", r"suggest (new )?topics", r"recommend (to )?learn",
r"next (topics|lessons)", r"learning suggestions", r"what to learn",
r"что учить дальше", r"предложи темы", r"рекомендации по обучению"
]
# Stats keywords
stats_keywords = [
r"show stats", r"learning statistics", r"progress stats", r"knowledge stats",
r"статистика обучения", r"прогресс статистика"
]
# Learning mode keywords
learn_keywords = [
r"start learning", r"learning mode", r"learn new", r"study plan",
r"начать обучение", r"режим обучения"
]
# Check patterns
for pattern in analysis_keywords:
if re.search(pattern, query_lower):
return "ANALYSIS"
for pattern in suggestion_keywords:
if re.search(pattern, query_lower):
return "SUGGEST"
for pattern in stats_keywords:
if re.search(pattern, query_lower):
return "STATS"
for pattern in learn_keywords:
if re.search(pattern, query_lower):
return "LEARN"
return "SEARCH"
# =========================
# DOCUMENT PROCESSING
# =========================
def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
"""Split oversized chunks into smaller pieces"""
if len(text) <= max_chars:
return [text]
sentences = text.replace('. ', '.|').replace('! ', '!|').replace('? ', '?|').split('|')
chunks = []
current = ""
    for sentence in sentences:
        # The split above consumed the delimiter space, so reinsert it when joining.
        joined = f"{current} {sentence}" if current else sentence
        if len(joined) <= max_chars:
            current = joined
else:
if current: chunks.append(current.strip())
if len(sentence) > max_chars:
words = sentence.split()
temp = ""
for word in words:
if len(temp) + len(word) + 1 <= max_chars:
temp += word + " "
else:
if temp: chunks.append(temp.strip())
temp = word + " "
if temp: chunks.append(temp.strip())
current = ""
else:
current = sentence
if current: chunks.append(current.strip())
return [c for c in chunks if c]
def parse_markdown_with_frontmatter(file_path: str) -> tuple[dict, str]:
"""Parse markdown file and extract YAML frontmatter + content"""
content = Path(file_path).read_text(encoding='utf-8')
# YAML frontmatter pattern
frontmatter_pattern = r'^---\s*\n(.*?)\n---\s*\n(.*)$'
match = re.match(frontmatter_pattern, content, re.DOTALL)
if match:
try:
metadata = yaml.safe_load(match.group(1))
metadata = metadata if isinstance(metadata, dict) else {}
return metadata, match.group(2)
except yaml.YAMLError as e:
console.print(f"[yellow]⚠️ YAML error in {Path(file_path).name}: {e}[/yellow]")
return {}, content
return {}, content
class ChunkProcessor:
"""Handles document chunking and embedding"""
def __init__(self, vectorstore):
self.vectorstore = vectorstore
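        # Bounds how many files are processed concurrently (shared across index_file calls).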
self.semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
async def process_file(self, file_path: str) -> List[Dict]:
"""Process a single markdown file into chunks"""
try:
metadata, content = parse_markdown_with_frontmatter(file_path)
metadata["source"] = file_path
if metadata.get('exclude'):
console.print(f"[dim]📋 Found excluded file: {Path(file_path).name}[/dim]")
docs = [Document(page_content=content, metadata=metadata)]
except Exception as e:
console.print(f"{Path(file_path).name}: {e}", style="red")
return []
splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE,
chunk_overlap=CHUNK_OVERLAP,
separators=["\n\n", "\n", ". ", " "]
)
chunks = []
for doc_idx, doc in enumerate(docs):
doc_metadata = doc.metadata
for chunk_idx, text in enumerate(splitter.split_text(doc.page_content)):
safe_texts = validate_chunk_size(text)
for sub_idx, safe_text in enumerate(safe_texts):
chunks.append({
"id": f"{file_path}::{doc_idx}::{chunk_idx}::{sub_idx}",
"text": safe_text,
"metadata": doc_metadata
})
return chunks
async def embed_batch(self, batch: List[Dict]) -> bool:
"""Embed a batch of chunks"""
if not batch:
return True
try:
docs = [Document(page_content=c["text"], metadata=c["metadata"]) for c in batch]
ids = [c["id"] for c in batch]
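            # Chroma only accepts scalar metadata (str/int/float/bool);
            # filter_complex_metadata drops list/dict values from frontmatter.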
docs = filter_complex_metadata(docs)
await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids)
return True
except Exception as e:
console.print(f"✗ Embed error: {e}", style="red")
return False
async def index_file(self, file_path: str, cache: dict) -> bool:
"""Index a single file with change detection"""
async with self.semaphore:
current_hash = get_file_hash(file_path)
if cache.get(file_path) == current_hash:
return False
chunks = await self.process_file(file_path)
if not chunks: return False
            # Remove stale chunks for this file (via Chroma's private collection handle)
            try:
                self.vectorstore._collection.delete(where={"source": {"$eq": file_path}})
            except Exception:
                pass
# Embed new chunks in batches
for i in range(0, len(chunks), BATCH_SIZE):
batch = chunks[i:i + BATCH_SIZE]
await self.embed_batch(batch)
cache[file_path] = current_hash
console.print(f"{Path(file_path).name} ({len(chunks)} chunks)", style="green")
return True
# =========================
# FILE WATCHER
# =========================
class DocumentWatcher(FileSystemEventHandler):
"""Watch for file changes and reindex automatically"""
def __init__(self, processor, cache):
self.processor = processor
self.cache = cache
self.queue = deque()
self.processing = False
def on_modified(self, event):
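        # Called from watchdog's observer thread; deque.append is atomic in
        # CPython, so handing the path to the asyncio consumer below is safe.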
if not event.is_directory and event.src_path.endswith(".md"):
self.queue.append(event.src_path)
async def process_queue(self):
while True:
if self.queue and not self.processing:
self.processing = True
file_path = self.queue.popleft()
if Path(file_path).exists():
await self.processor.index_file(file_path, self.cache)
save_hash_cache(self.cache)
self.processing = False
await asyncio.sleep(1)
def start_watcher(processor, cache):
"""Start file system watcher"""
handler = DocumentWatcher(processor, cache)
observer = Observer()
observer.schedule(handler, MD_DIRECTORY, recursive=True)
observer.start()
asyncio.create_task(handler.process_queue())
return observer
# =========================
# CONVERSATION MEMORY
# =========================
class ConversationMemory:
"""Manage conversation history"""
def __init__(self, max_messages: int = 8):
self.messages = []
self.max_messages = max_messages
def add(self, role: str, content: str):
self.messages.append({"role": role, "content": content})
if len(self.messages) > self.max_messages:
self.messages.pop(0)
def get_history(self) -> str:
if not self.messages: return "No previous conversation."
return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])
# =========================
# LEARNING ANALYTICS
# =========================
class LearningAnalytics:
"""Analyze learning progress and provide insights"""
def __init__(self, vectorstore):
self.vectorstore = vectorstore
async def get_knowledge_summary(self) -> dict:
"""Get comprehensive knowledge base summary"""
try:
db_data = await asyncio.to_thread(self.vectorstore.get)
if not db_data or not db_data['documents']:
return {"total_docs": 0, "total_chunks": 0, "subjects": {}}
# Filter excluded documents
filtered_pairs = [
(text, meta) for text, meta in zip(db_data['documents'], db_data['metadatas'])
if meta and not meta.get('exclude', False)
]
# Extract subjects/topics from file names and content
subjects = defaultdict(lambda: {"chunks": 0, "files": set(), "last_updated": None})
for text, meta in filtered_pairs:
source = meta.get('source', 'unknown')
filename = Path(source).stem
                # Simple subject extraction: first word of the filename
                subject = (filename.split() or ["Unknown"])[0]
subjects[subject]["chunks"] += 1
subjects[subject]["files"].add(source)
# Track last update (simplified)
if not subjects[subject]["last_updated"]:
subjects[subject]["last_updated"] = datetime.now().isoformat()
            # Convert sets to counts
            for subject in subjects:
                subjects[subject]["files"] = len(subjects[subject]["files"])
            return {
                "total_docs": len({meta.get("source") for _, meta in filtered_pairs}),
                "total_chunks": len(filtered_pairs),
                "subjects": dict(subjects)
            }
except Exception as e:
console.print(f"[red]✗ Error getting knowledge summary: {e}[/red]")
return {"total_docs": 0, "total_chunks": 0, "subjects": {}}
async def get_learning_stats(self) -> dict:
"""Get detailed learning statistics"""
summary = await self.get_knowledge_summary()
# Load progress history
progress_cache = load_progress_cache()
stats = {
"total_topics": len(summary["subjects"]),
"total_notes": summary["total_docs"],
"total_files": sum(s["files"] for s in summary["subjects"].values()),
"topics": list(summary["subjects"].keys()),
"progress_history": progress_cache.get("sessions", []),
"study_streak": self._calculate_streak(progress_cache.get("sessions", [])),
"most_productive_topic": self._get_most_productive_topic(summary["subjects"])
}
return stats
def _calculate_streak(self, sessions: list) -> int:
"""Calculate consecutive days of studying"""
        if not sessions:
            return 0
        # Simplified streak: walk the last 10 sessions' dates, deduplicated so
        # multiple sessions on the same day count as one.
        dates = sorted({
            datetime.fromisoformat(s.get("date", datetime.now().isoformat())).date()
            for s in sessions[-10:]
        })
        streak = 0
        current_date = datetime.now().date()
        for date in reversed(dates):
            if (current_date - date).days <= 1:
                streak += 1
                current_date = date
            else:
                break
        return streak
def _get_most_productive_topic(self, subjects: dict) -> str:
"""Identify the most studied topic"""
if not subjects:
return "None"
return max(subjects.items(), key=lambda x: x[1]["chunks"])[0]
# =========================
# CHAIN FACTORY
# =========================
def get_chain(system_prompt):
"""Create a LangChain processing chain"""
llm = ChatOllama(
model=LLM_MODEL,
temperature=0.2,
base_url=OLLAMA_BASE_URL
)
prompt = ChatPromptTemplate.from_messages([
("system", system_prompt),
("human", USER_PROMPT_TEMPLATE)
])
return prompt | llm | StrOutputParser()
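# Usage sketch (assumes an Ollama server reachable at OLLAMA_BASE_URL):
#   chain = get_chain(SYSTEM_PROMPT_SEARCH)
#   async for token in chain.astream({"context": ctx, "question": q, "history": h}):
#       console.print(token, end="")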
# =========================
# INTERACTIVE COMMANDS
# =========================
class InteractiveCommands:
"""Handle interactive learning commands"""
def __init__(self, vectorstore, analytics):
self.vectorstore = vectorstore
self.analytics = analytics
async def list_excluded_files(self):
"""List all files marked with exclude: true"""
console.print("\n[bold yellow]📋 Fetching list of excluded files...[/bold yellow]")
try:
excluded_data = await asyncio.to_thread(
self.vectorstore.get,
where={"exclude": True}
)
if not excluded_data or not excluded_data['metadatas']:
console.print("[green]✓ No files are marked for exclusion.[/green]")
return
excluded_files = set()
for meta in excluded_data['metadatas']:
if meta and 'source' in meta:
excluded_files.add(Path(meta['source']).name)
console.print(f"\n[bold red]❌ Excluded Files ({len(excluded_files)}):[/bold red]")
console.print("=" * 50, style="dim")
for filename in sorted(excluded_files):
console.print(f"{filename}", style="red")
console.print("=" * 50, style="dim")
console.print(f"[dim]Total chunks excluded: {len(excluded_data['metadatas'])}[/dim]\n")
except Exception as e:
console.print(f"[red]✗ Error fetching excluded files: {e}[/red]")
async def show_learning_stats(self):
"""Display comprehensive learning statistics"""
console.print("\n[bold cyan]📊 Learning Statistics[/bold cyan]")
console.print("=" * 60, style="dim")
stats = await self.analytics.get_learning_stats()
# Display stats in a table
table = Table(title="Knowledge Overview", show_header=False)
table.add_column("Metric", style="cyan")
table.add_column("Value", style="yellow")
table.add_row("Total Topics Studied", str(stats["total_topics"]))
table.add_row("Total Notes Created", str(stats["total_notes"]))
table.add_row("Total Files", str(stats["total_files"]))
table.add_row("Study Streak (days)", str(stats["study_streak"]))
table.add_row("Most Productive Topic", stats["most_productive_topic"])
console.print(table)
# Show topics
if stats["topics"]:
console.print(f"\n[bold green]📚 Topics Studied:[/bold green]")
for topic in sorted(stats["topics"]):
console.print(f"{topic}")
console.print()
async def interactive_learning_mode(self):
"""Start interactive learning mode"""
console.print("\n[bold magenta]🎓 Interactive Learning Mode[/bold magenta]")
console.print("I'll analyze your current knowledge and suggest what to learn next!\n")
# First, analyze current knowledge
console.print("[cyan]Analyzing your current knowledge base...[/cyan]")
# Get analysis
db_data = await asyncio.to_thread(self.vectorstore.get)
all_texts = db_data['documents']
all_metadatas = db_data['metadatas']
# Filter excluded
filtered_pairs = [
(text, meta) for text, meta in zip(all_texts, all_metadatas)
if meta and not meta.get('exclude', False)
]
if not filtered_pairs:
console.print("[yellow]⚠️ No learning materials found. Add some notes first![/yellow]")
return
# Build context for analysis
full_context = ""
for text, meta in filtered_pairs[:20]: # Limit context
full_context += f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
# Get AI analysis
chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
console.print("[cyan]Getting AI analysis of your progress...[/cyan]")
analysis_response = ""
async for chunk in chain.astream({
"context": full_context,
"question": "Analyze my learning progress and identify what I've learned well and what gaps exist.",
"history": ""
}):
analysis_response += chunk
console.print(f"\n[bold green]📈 Your Learning Analysis:[/bold green]")
console.print(analysis_response)
# Get suggestions
console.print("\n[cyan]Generating personalized learning suggestions...[/cyan]")
suggestion_chain = get_chain(SYSTEM_PROMPT_SUGGESTION)
suggestion_response = ""
async for chunk in suggestion_chain.astream({
"context": full_context,
"question": "Based on this student's current knowledge, what should they learn next?",
"history": ""
}):
suggestion_response += chunk
console.print(f"\n[bold blue]💡 Recommended Next Topics:[/bold blue]")
console.print(suggestion_response)
# Save progress
progress_cache = load_progress_cache()
if "sessions" not in progress_cache:
progress_cache["sessions"] = []
progress_cache["sessions"].append({
"date": datetime.now().isoformat(),
"type": "analysis",
"topics_count": len(filtered_pairs)
})
save_progress_cache(progress_cache)
console.print(f"\n[green]✓ Analysis complete! Add notes about the suggested topics and run 'learning mode' again.[/green]")
async def suggest_topics(self):
"""Suggest new topics to learn"""
console.print("\n[bold blue]💡 Topic Suggestions[/bold blue]")
# Get current knowledge
db_data = await asyncio.to_thread(self.vectorstore.get)
all_texts = db_data['documents']
all_metadatas = db_data['metadatas']
filtered_pairs = [
(text, meta) for text, meta in zip(all_texts, all_metadatas)
if meta and not meta.get('exclude', False)
][:15] # Limit context
if not filtered_pairs:
console.print("[yellow]⚠️ No notes found. Start by creating some learning materials![/yellow]")
return
# Build context
context = ""
for text, meta in filtered_pairs:
context += f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
# Get suggestions from AI
chain = get_chain(SYSTEM_PROMPT_SUGGESTION)
console.print("[cyan]Analyzing your knowledge and generating suggestions...[/cyan]\n")
response = ""
async for chunk in chain.astream({
"context": context,
"question": "What are the next logical topics for this student to learn?",
"history": ""
}):
response += chunk
console.print(chunk, end="")
console.print("\n")
async def exclude_file_interactive(self):
"""Interactively exclude a file from learning analysis"""
console.print("\n[bold yellow]📁 Exclude File from Analysis[/bold yellow]")
# List all non-excluded files
db_data = await asyncio.to_thread(self.vectorstore.get)
files = set()
for meta in db_data['metadatas']:
if meta and 'source' in meta and not meta.get('exclude', False):
files.add(meta['source'])
if not files:
console.print("[yellow]⚠️ No files found to exclude.[/yellow]")
return
# Show files
file_list = sorted(list(files))
console.print("\n[bold]Available files:[/bold]")
for i, file_path in enumerate(file_list, 1):
console.print(f" {i}. {Path(file_path).name}")
# Get user choice
choice = Prompt.ask("\nSelect file number to exclude",
choices=[str(i) for i in range(1, len(file_list) + 1)],
default="1")
selected_file = file_list[int(choice) - 1]
# Confirmation
if Confirm.ask(f"\nExclude '{Path(selected_file).name}' from learning analysis?"):
# Update the file's metadata in vectorstore
try:
# Note: In a real implementation, you'd need to update the file's frontmatter
# For now, we'll show instructions
console.print(f"\n[red]⚠️ Manual action required:[/red]")
console.print(f"Add 'exclude: true' to the frontmatter of:")
console.print(f" {selected_file}")
console.print(f"\n[dim]Example:[/dim]")
console.print("```\n---\nexclude: true\n---\n```")
console.print(f"\n[green]The file will be excluded on next reindex.[/green]")
except Exception as e:
console.print(f"[red]✗ Error: {e}[/red]")
# =========================
# MAIN APPLICATION
# =========================
async def main():
"""Main application entry point"""
# Setup directories
Path(MD_DIRECTORY).mkdir(parents=True, exist_ok=True)
Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)
# Display welcome banner
console.print(Panel.fit(
f"[bold cyan]⚡ RAG Learning System[/bold cyan]\n"
f"📂 Notes Directory: {MD_DIRECTORY}\n"
f"🧠 Embedding Model: {EMBEDDING_MODEL}\n"
f"🤖 LLM Model: {LLM_MODEL}\n"
f"[dim]Commands: /help for available commands[/dim]",
border_style="cyan"
))
# Initialize components
embeddings = OllamaEmbeddings(
model=EMBEDDING_MODEL,
base_url=OLLAMA_BASE_URL
)
vectorstore = Chroma(
collection_name=COLLECTION_NAME,
persist_directory=CHROMA_PATH,
embedding_function=embeddings
)
processor = ChunkProcessor(vectorstore)
analytics = LearningAnalytics(vectorstore)
commands = InteractiveCommands(vectorstore, analytics)
cache = load_hash_cache()
# Index existing documents
console.print(f"\n[bold yellow]📚 Indexing documents...[/bold yellow]")
files = [
os.path.join(root, file)
for root, _, files in os.walk(MD_DIRECTORY)
for file in files if file.endswith(".md")
]
semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
async def sem_task(fp):
async with semaphore:
return await processor.index_file(fp, cache)
# Use progress bar for indexing
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console
) as progress:
task = progress.add_task("Indexing files...", total=len(files))
tasks = [sem_task(fp) for fp in files]
for fut in asyncio.as_completed(tasks):
await fut
progress.advance(task)
save_hash_cache(cache)
# Start file watcher
observer = start_watcher(processor, cache)
memory = ConversationMemory()
# Show help hint
console.print(f"\n[dim]💡 Type /help to see available commands[/dim]\n")
try:
while True:
# Get user input
query = await session.prompt_async("> ", style=style)
query = query.strip()
if not query:
continue
# Handle commands
if query.startswith('/'):
command = query[1:].lower().strip()
if command in ['exit', 'quit', 'q']:
console.print("\n👋 Goodbye!", style="yellow")
break
elif command in ['help', 'h']:
await show_help()
elif command in ['stats', 'statistics']:
await commands.show_learning_stats()
elif command in ['excluded', 'list-excluded']:
await commands.list_excluded_files()
elif command in ['learning-mode', 'learn']:
await commands.interactive_learning_mode()
elif command in ['suggest', 'suggestions']:
await commands.suggest_topics()
elif command in ['exclude']:
await commands.exclude_file_interactive()
elif command in ['reindex']:
console.print("\n[yellow]🔄 Reindexing all files...[/yellow]")
cache.clear()
for file_path in files:
await processor.index_file(file_path, cache)
save_hash_cache(cache)
console.print("[green]✓ Reindexing complete![/green]")
else:
console.print(f"[red]✗ Unknown command: {command}[/red]")
console.print("[dim]Type /help to see available commands[/dim]")
continue
# Process normal queries
console.print()
mode = classify_intent(query)
history_str = memory.get_history()
if mode == "SEARCH":
console.print("🔍 SEARCH MODE (Top-K Retrieval)", style="bold blue")
retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
docs = await asyncio.to_thread(retriever.invoke, query)
context_str = "\n\n".join(
f"[{Path(d.metadata['source']).name}]\n{d.page_content}"
for d in docs
)
chain = get_chain(SYSTEM_PROMPT_SEARCH)
elif mode == "ANALYSIS":
console.print("📊 ANALYSIS MODE (Full Context Evaluation)", style="bold magenta")
db_data = await asyncio.to_thread(vectorstore.get)
all_texts = db_data['documents']
all_metas = db_data['metadatas']
if not all_texts:
console.print("[red]No documents found to analyze![/red]")
continue
# Filter excluded chunks
filtered_pairs = [
(text, meta) for text, meta in zip(all_texts, all_metas)
if meta and not meta.get('exclude', False)
]
excluded_count = len(all_texts) - len(filtered_pairs)
if excluded_count > 0:
console.print(f" Excluded {excluded_count} chunks marked 'exclude: true'", style="dim")
if not filtered_pairs:
console.print("[yellow]All documents are marked for exclusion. Nothing to analyze.[/yellow]")
continue
# Build context
full_context = ""
char_count = 0
for text, meta in filtered_pairs[:25]: # Limit for analysis
entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
full_context += "\n[...Truncated due to context limit...]"
console.print("⚠ Context limit reached, truncating analysis data.", style="yellow")
break
full_context += entry
char_count += len(entry)
context_str = full_context
chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
elif mode == "SUGGEST":
await commands.suggest_topics()
continue
elif mode == "STATS":
await commands.show_learning_stats()
continue
elif mode == "LEARN":
await commands.interactive_learning_mode()
continue
# Generate and display response
response = ""
console.print(f"Context size: {len(context_str)} chars", style="dim")
console.print("Assistant:", style="blue", end=" ")
async for chunk in chain.astream({
"context": context_str,
"question": query,
"history": history_str
}):
console.print(chunk, end="", style=ANSWER_COLOR)
response += chunk
console.print("\n")
# Update conversation memory
memory.add("user", query)
memory.add("assistant", response)
finally:
# Cleanup
observer.stop()
observer.join()
async def show_help():
"""Display help information"""
console.print("\n[bold cyan]📖 Available Commands:[/bold cyan]")
console.print("=" * 50, style="dim")
commands = [
("/help", "Show this help message"),
("/stats", "Display learning statistics and progress"),
("/learning-mode", "Start interactive learning analysis"),
("/suggest", "Get topic suggestions for next study"),
("/excluded", "List files excluded from analysis"),
("/exclude", "Interactively exclude a file"),
("/reindex", "Reindex all documents"),
("/exit, /quit, /q", "Exit the application"),
]
for cmd, desc in commands:
console.print(f"[yellow]{cmd:<20}[/yellow] {desc}")
console.print("\n[bold cyan]🎯 Learning Modes:[/bold cyan]")
console.print("=" * 50, style="dim")
console.print("• [blue]Search Mode[/blue]: Ask questions about your notes")
console.print("• [magenta]Analysis Mode[/magenta]: Get progress evaluation")
console.print("• [green]Suggestion Mode[/green]: Get topic recommendations")
console.print("\n[bold cyan]💡 Examples:[/bold cyan]")
console.print("=" * 50, style="dim")
console.print("\"What is SQL JOIN?\" → Search your notes")
console.print("\"Assess my progress\" → Analyze learning")
console.print("\"What should I learn next?\" → Get suggestions")
console.print("\"Show my statistics\" → Display progress")
console.print()
if __name__ == "__main__":
import nest_asyncio
nest_asyncio.apply()
    try:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
except KeyboardInterrupt:
console.print("\n👋 Goodbye!", style="yellow")
sys.exit(0)
except Exception as e:
console.print(f"\n[red]✗ Unexpected error: {e}[/red]")
sys.exit(1)