From 6af63cf8f1295d3618ba1013f7c919a0350c4941 Mon Sep 17 00:00:00 2001
From: y9938
Date: Wed, 31 Dec 2025 02:01:34 +0300
Subject: [PATCH] Test new script

---
 main.py        | 736 +++++++++++++++++++++++++++++++++++++++++++++----
 pyproject.toml |   1 +
 uv.lock        |   2 +
 3 files changed, 681 insertions(+), 58 deletions(-)

diff --git a/main.py b/main.py
index 7daf7ca..d34501f 100644
--- a/main.py
+++ b/main.py
@@ -1,21 +1,33 @@
 #!/usr/bin/env python3
+"""
+RAG Learning System
+A dual-mode RAG system designed for progressive learning with AI guidance.
+Tracks your knowledge, suggests new topics, and helps identify learning gaps.
+"""
+
 import os
 import sys
 import json
 import hashlib
 import asyncio
 import re
+import yaml
 from pathlib import Path
-from collections import deque
-from typing import List, Dict
+from collections import deque, defaultdict
+from typing import List, Dict, Set
+from datetime import datetime, timedelta
 from dotenv import load_dotenv
 from rich.console import Console
 from rich.panel import Panel
+from rich.table import Table
+from rich.prompt import Prompt, Confirm
+from rich.progress import Progress, SpinnerColumn, TextColumn
 from prompt_toolkit import PromptSession
 from prompt_toolkit.styles import Style
 from langchain_community.document_loaders import UnstructuredMarkdownLoader
+from langchain_community.vectorstores.utils import filter_complex_metadata
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 from langchain_ollama import OllamaEmbeddings, ChatOllama
 from langchain_chroma import Chroma
@@ -27,7 +39,7 @@ from watchdog.observers import Observer
 from watchdog.events import FileSystemEventHandler
 
 # =========================
-# CONFIG
+# CONFIGURATION
 # =========================
 console = Console(color_system="standard", force_terminal=True)
 session = PromptSession()
@@ -35,76 +47,191 @@ load_dotenv()
 
 style = Style.from_dict({"prompt": "bold #6a0dad"})
 
+# Core Configuration
 OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
 ANSWER_COLOR = os.getenv("ANSWER_COLOR", "blue")
 
-SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
+# Enhanced System Prompts
+SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT",
+    "You are a precise technical assistant. Use the provided context to answer questions accurately. "
+    "Cite sources using [filename]. If the context doesn't contain the answer, say so.")
+
 SYSTEM_PROMPT_ANALYSIS = (
-    "You are an expert tutor and progress evaluator. "
-    "You have access to the student's entire knowledge base below. "
-    "Analyze the coverage, depth, and connections in the notes. "
-    "Identify what the user has learned well, what is missing, and suggest the next logical steps. "
-    "Do not just summarize; evaluate the progress."
+    "You are an expert learning analytics tutor. Your task is to analyze a student's knowledge base "
+    "and provide insights about their learning progress.\n\n"
+    "When analyzing, consider:\n"
+    "1. What topics/subjects are covered in the notes\n"
+    "2. The depth and complexity of understanding demonstrated\n"
+    "3. Connections between different concepts\n"
+    "4. Gaps or missing fundamental concepts\n"
+    "5. Progression from beginner to advanced topics\n\n"
+    "Provide specific, actionable feedback about:\n"
+    "- What the student has learned well\n"
+    "- Areas that need more attention\n"
+    "- Recommended next topics to study\n"
+    "- How new topics connect to existing knowledge\n\n"
+    "Be encouraging but honest. Format your response clearly with sections."
+)
+
+SYSTEM_PROMPT_SUGGESTION = (
+    "You are a learning path advisor. Based on a student's current knowledge (shown in their notes), "
+    "suggest the next logical topics or skills to learn.\n\n"
+    "Your suggestions should:\n"
+    "1. Build upon existing knowledge\n"
+    "2. Fill identified gaps in understanding\n"
+    "3. Progress naturally from basics to advanced\n"
+    "4. Be specific and actionable\n\n"
+    "Format your response with:\n"
+    "- Recommended topics (with brief explanations)\n"
+    "- Prerequisites needed\n"
+    "- Why each topic is important\n"
+    "- Estimated difficulty level\n"
+    "- How it connects to what they already know"
 )
 
 USER_PROMPT_TEMPLATE = os.getenv("USER_PROMPT_TEMPLATE", "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
 
+# Paths and Models
 MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
-EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
-LLM_MODEL = os.getenv("LLM_MODEL", "llama3")
+EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "mxbai-embed-large:latest")
+LLM_MODEL = os.getenv("LLM_MODEL", "qwen2.5:7b-instruct-q8_0")
 CHROMA_PATH = "./.cache/chroma_db"
 HASH_CACHE = "./.cache/file_hashes.json"
+PROGRESS_CACHE = "./.cache/learning_progress.json"
 
-MAX_EMBED_CHARS = 380
-CHUNK_SIZE = 1200
+# Processing Configuration
+MAX_EMBED_CHARS = 380
+CHUNK_SIZE = 1200
 CHUNK_OVERLAP = 200
 TOP_K = 6
 COLLECTION_NAME = "md_rag"
-MAX_ANALYSIS_CONTEXT_CHARS = 24000
-
-BATCH_SIZE = 10
+MAX_ANALYSIS_CONTEXT_CHARS = 24000
+BATCH_SIZE = 10
 MAX_PARALLEL_FILES = 3
 
+# Learning Configuration
+MAX_SUGGESTIONS = 5
+PROGRESS_SUMMARY_DAYS = 7
+
 # =========================
-# UTILS & CACHE
+# UTILITY FUNCTIONS
 # =========================
 def get_file_hash(file_path: str) -> str:
+    """Generate MD5 hash for file change detection"""
     return hashlib.md5(Path(file_path).read_bytes()).hexdigest()
 
-def load_hash_cache() -> dict:
-    Path(HASH_CACHE).parent.mkdir(parents=True, exist_ok=True)
-    if Path(HASH_CACHE).exists():
-        return json.loads(Path(HASH_CACHE).read_text())
+def load_json_cache(file_path: str) -> dict:
+    """Load JSON cache with error handling"""
+    Path(file_path).parent.mkdir(parents=True, exist_ok=True)
+    if Path(file_path).exists():
+        try:
+            return json.loads(Path(file_path).read_text())
+        except json.JSONDecodeError:
+            console.print(f"[yellow]⚠️ Corrupted cache: {file_path}. Resetting.[/yellow]")
+            return {}
     return {}
 
+def save_json_cache(cache: dict, file_path: str):
+    """Save JSON cache with error handling"""
+    try:
+        Path(file_path).write_text(json.dumps(cache, indent=2))
+    except Exception as e:
+        console.print(f"[red]✗ Failed to save cache {file_path}: {e}[/red]")
+
+def load_hash_cache() -> dict:
+    """Load file hash cache"""
+    return load_json_cache(HASH_CACHE)
+
 def save_hash_cache(cache: dict):
-    Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))
+    """Save file hash cache"""
+    save_json_cache(cache, HASH_CACHE)
+
+def load_progress_cache() -> dict:
+    """Load learning progress cache"""
+    return load_json_cache(PROGRESS_CACHE)
+
+def save_progress_cache(cache: dict):
+    """Save learning progress cache"""
+    save_json_cache(cache, PROGRESS_CACHE)
+
+def format_file_size(size_bytes: int) -> str:
+    """Format file size for human reading"""
+    if size_bytes < 1024:
+        return f"{size_bytes} B"
+    elif size_bytes < 1024 * 1024:
+        return f"{size_bytes / 1024:.1f} KB"
+    else:
+        return f"{size_bytes / (1024 * 1024):.1f} MB"
 
 # =========================
-# ROUTING LOGIC
+# INTENT CLASSIFICATION
 # =========================
 def classify_intent(query: str) -> str:
+    """
+    Classify user intent into different modes:
+    - SEARCH: Standard RAG retrieval
+    - ANALYSIS: Progress and knowledge analysis
+    - SUGGEST: Topic and learning suggestions
+    - LEARN: Interactive learning mode
+    - STATS: Progress statistics
+    """
+    query_lower = query.lower().strip()
+
+    # Analysis keywords (progress evaluation)
     analysis_keywords = [
         r"assess my progress", r"eval(uate)? my (learning|knowledge)",
         r"what have i learned", r"summary of (my )?notes",
-        r"my progress", r"learning path", r"knowledge gap",
+        r"my progress", r"learning path", r"knowledge gap", r"analyze my",
         r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний",
         r"сегодня(?:\s+\w+)*\s*урок", r"что я изучил"
     ]
-    query_lower = query.lower()
 
+    # Suggestion keywords
+    suggestion_keywords = [
+        r"what should i learn next", r"suggest (new )?topics", r"recommend (to )?learn",
+        r"next (topics|lessons)", r"learning suggestions", r"what to learn",
+        r"что учить дальше", r"предложи темы", r"рекомендации по обучению"
+    ]
+
+    # Stats keywords
+    stats_keywords = [
+        r"show stats", r"learning statistics", r"progress stats", r"knowledge stats",
+        r"статистика обучения", r"прогресс статистика"
+    ]
+
+    # Learning mode keywords
+    learn_keywords = [
+        r"start learning", r"learning mode", r"learn new", r"study plan",
+        r"начать обучение", r"режим обучения"
+    ]
+
+    # Check patterns
     for pattern in analysis_keywords:
         if re.search(pattern, query_lower):
             return "ANALYSIS"
+
+    for pattern in suggestion_keywords:
+        if re.search(pattern, query_lower):
+            return "SUGGEST"
+
+    for pattern in stats_keywords:
+        if re.search(pattern, query_lower):
+            return "STATS"
+
+    for pattern in learn_keywords:
+        if re.search(pattern, query_lower):
+            return "LEARN"
+
     return "SEARCH"
 
 # =========================
 # DOCUMENT PROCESSING
 # =========================
 def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
+    """Split oversized chunks into smaller pieces"""
     if len(text) <= max_chars:
         return [text]
 
@@ -134,14 +261,42 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str
     if current:
         chunks.append(current.strip())
     return [c for c in chunks if c]
 
+def parse_markdown_with_frontmatter(file_path: str) -> tuple[dict, str]:
+    """Parse markdown file and extract YAML frontmatter + content"""
+    content = Path(file_path).read_text(encoding='utf-8')
+
+    # YAML frontmatter pattern
+    frontmatter_pattern = r'^---\s*\n(.*?)\n---\s*\n(.*)$'
+    match = re.match(frontmatter_pattern, content, re.DOTALL)
+
+    if match:
+        try:
+            metadata = yaml.safe_load(match.group(1))
+            metadata = metadata if isinstance(metadata, dict) else {}
+            return metadata, match.group(2)
+        except yaml.YAMLError as e:
+            console.print(f"[yellow]⚠️ YAML error in {Path(file_path).name}: {e}[/yellow]")
+            return {}, content
+
+    return {}, content
+
 class ChunkProcessor:
+    """Handles document chunking and embedding"""
     def __init__(self, vectorstore):
         self.vectorstore = vectorstore
         self.semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
 
     async def process_file(self, file_path: str) -> List[Dict]:
+        """Process a single markdown file into chunks"""
         try:
-            docs = await asyncio.to_thread(UnstructuredMarkdownLoader(file_path).load)
+            metadata, content = parse_markdown_with_frontmatter(file_path)
+            metadata["source"] = file_path
+
+            if metadata.get('exclude'):
+                console.print(f"[dim]📋 Found excluded file: {Path(file_path).name}[/dim]")
+
+            docs = [Document(page_content=content, metadata=metadata)]
+
         except Exception as e:
             console.print(f"✗ {Path(file_path).name}: {e}", style="red")
             return []
@@ -154,28 +309,37 @@ class ChunkProcessor:
 
         chunks = []
         for doc_idx, doc in enumerate(docs):
+            doc_metadata = doc.metadata
+
             for chunk_idx, text in enumerate(splitter.split_text(doc.page_content)):
                 safe_texts = validate_chunk_size(text)
                 for sub_idx, safe_text in enumerate(safe_texts):
                     chunks.append({
                         "id": f"{file_path}::{doc_idx}::{chunk_idx}::{sub_idx}",
                         "text": safe_text,
-                        "metadata": {"source": file_path, **doc.metadata}
+                        "metadata": doc_metadata
                     })
         return chunks
 
     async def embed_batch(self, batch: List[Dict]) -> bool:
-        if not batch: return True
+        """Embed a batch of chunks"""
+        if not batch:
+            return True
+
         try:
             docs = [Document(page_content=c["text"], metadata=c["metadata"]) for c in batch]
             ids = [c["id"] for c in batch]
+
+            docs = filter_complex_metadata(docs)
             await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids)
             return True
+
         except Exception as e:
             console.print(f"✗ Embed error: {e}", style="red")
             return False
 
     async def index_file(self, file_path: str, cache: dict) -> bool:
+        """Index a single file with change detection"""
         async with self.semaphore:
             current_hash = get_file_hash(file_path)
             if cache.get(file_path) == current_hash:
@@ -184,11 +348,13 @@ class ChunkProcessor:
             chunks = await self.process_file(file_path)
             if not chunks: return False
 
+            # Remove old chunks for this file
             try:
-                self.vectorstore._collection.delete(where={"source": file_path})
+                self.vectorstore._collection.delete(where={"source": {"$eq": file_path}})
             except:
-                pass
+                pass
 
+            # Embed new chunks in batches
             for i in range(0, len(chunks), BATCH_SIZE):
                 batch = chunks[i:i + BATCH_SIZE]
                 await self.embed_batch(batch)
@@ -201,6 +367,7 @@ class ChunkProcessor:
 # FILE WATCHER
 # =========================
 class DocumentWatcher(FileSystemEventHandler):
+    """Watch for file changes and reindex automatically"""
     def __init__(self, processor, cache):
         self.processor = processor
         self.cache = cache
@@ -223,6 +390,7 @@ class DocumentWatcher(FileSystemEventHandler):
         await asyncio.sleep(1)
 
 def start_watcher(processor, cache):
+    """Start file system watcher"""
     handler = DocumentWatcher(processor, cache)
     observer = Observer()
     observer.schedule(handler, MD_DIRECTORY, recursive=True)
@@ -231,9 +399,10 @@ def start_watcher(processor, cache):
     return observer
 
 # =========================
-# RAG CHAIN FACTORY
+# CONVERSATION MEMORY
 # =========================
 class ConversationMemory:
+    """Manage conversation history"""
     def __init__(self, max_messages: int = 8):
         self.messages = []
         self.max_messages = max_messages
@@ -247,7 +416,112 @@ class ConversationMemory:
         if not self.messages: return "No previous conversation."
         return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])
 
+# =========================
+# LEARNING ANALYTICS
+# =========================
+class LearningAnalytics:
+    """Analyze learning progress and provide insights"""
+
+    def __init__(self, vectorstore):
+        self.vectorstore = vectorstore
+
+    async def get_knowledge_summary(self) -> dict:
+        """Get comprehensive knowledge base summary"""
+        try:
+            db_data = await asyncio.to_thread(self.vectorstore.get)
+
+            if not db_data or not db_data['documents']:
+                return {"total_docs": 0, "total_chunks": 0, "subjects": {}}
+
+            # Filter excluded documents
+            filtered_pairs = [
+                (text, meta) for text, meta in zip(db_data['documents'], db_data['metadatas'])
+                if meta and not meta.get('exclude', False)
+            ]
+
+            # Extract subjects/topics from file names and content
+            subjects = defaultdict(lambda: {"chunks": 0, "files": set(), "last_updated": None})
+
+            for text, meta in filtered_pairs:
+                source = meta.get('source', 'unknown')
+                filename = Path(source).stem
+
+                # Simple subject extraction from filename
+                subject = filename.split()[0] if filename else 'Unknown'
+
+                subjects[subject]["chunks"] += 1
+                subjects[subject]["files"].add(source)
+
+                # Track last update (simplified)
+                if not subjects[subject]["last_updated"]:
+                    subjects[subject]["last_updated"] = datetime.now().isoformat()
+
+            # Convert sets to counts
+            for subject in subjects:
+                subjects[subject]["files"] = len(subjects[subject]["files"])
+
+            return {
+                "total_docs": len(filtered_pairs),
+                "total_chunks": len(filtered_pairs),
+                "subjects": dict(subjects)
+            }
+
+        except Exception as e:
+            console.print(f"[red]✗ Error getting knowledge summary: {e}[/red]")
+            return {"total_docs": 0, "total_chunks": 0, "subjects": {}}
+
+    async def get_learning_stats(self) -> dict:
+        """Get detailed learning statistics"""
+        summary = await self.get_knowledge_summary()
+
+        # Load progress history
+        progress_cache = load_progress_cache()
+
+        stats = {
+            "total_topics": len(summary["subjects"]),
+            "total_notes": summary["total_docs"],
+            "total_files": sum(s["files"] for s in summary["subjects"].values()),
+            "topics": list(summary["subjects"].keys()),
+            "progress_history": progress_cache.get("sessions", []),
+            "study_streak": self._calculate_streak(progress_cache.get("sessions", [])),
+            "most_productive_topic": self._get_most_productive_topic(summary["subjects"])
+        }
+
+        return stats
+
+    def _calculate_streak(self, sessions: list) -> int:
+        """Calculate consecutive days of studying"""
+        if not sessions:
+            return 0
+
+        # Simplified streak calculation over unique dates from the last 10 sessions
+        # (deduplicated so several sessions on the same day count once)
+        dates = sorted({
+            datetime.fromisoformat(s.get("date", datetime.now().isoformat())).date()
+            for s in sessions[-10:]
+        })
+
+        streak = 0
+        current_date = datetime.now().date()
+
+        for date in reversed(dates):
+            if (current_date - date).days <= 1:
+                streak += 1
+                current_date = date
+            else:
+                break
+
+        return streak
+
+    def _get_most_productive_topic(self, subjects: dict) -> str:
+        """Identify the most studied topic"""
+        if not subjects:
+            return "None"
+
+        return max(subjects.items(), key=lambda x: x[1]["chunks"])[0]
+
+# =========================
+# CHAIN FACTORY
+# =========================
 def get_chain(system_prompt):
"""Create a LangChain processing chain""" llm = ChatOllama( model=LLM_MODEL, temperature=0.2, @@ -260,24 +534,256 @@ def get_chain(system_prompt): return prompt | llm | StrOutputParser() # ========================= -# MAIN +# INTERACTIVE COMMANDS +# ========================= +class InteractiveCommands: + """Handle interactive learning commands""" + + def __init__(self, vectorstore, analytics): + self.vectorstore = vectorstore + self.analytics = analytics + + async def list_excluded_files(self): + """List all files marked with exclude: true""" + console.print("\n[bold yellow]📋 Fetching list of excluded files...[/bold yellow]") + + try: + excluded_data = await asyncio.to_thread( + self.vectorstore.get, + where={"exclude": True} + ) + + if not excluded_data or not excluded_data['metadatas']: + console.print("[green]✓ No files are marked for exclusion.[/green]") + return + + excluded_files = set() + for meta in excluded_data['metadatas']: + if meta and 'source' in meta: + excluded_files.add(Path(meta['source']).name) + + console.print(f"\n[bold red]❌ Excluded Files ({len(excluded_files)}):[/bold red]") + console.print("=" * 50, style="dim") + + for filename in sorted(excluded_files): + console.print(f" • {filename}", style="red") + + console.print("=" * 50, style="dim") + console.print(f"[dim]Total chunks excluded: {len(excluded_data['metadatas'])}[/dim]\n") + + except Exception as e: + console.print(f"[red]✗ Error fetching excluded files: {e}[/red]") + + async def show_learning_stats(self): + """Display comprehensive learning statistics""" + console.print("\n[bold cyan]📊 Learning Statistics[/bold cyan]") + console.print("=" * 60, style="dim") + + stats = await self.analytics.get_learning_stats() + + # Display stats in a table + table = Table(title="Knowledge Overview", show_header=False) + table.add_column("Metric", style="cyan") + table.add_column("Value", style="yellow") + + table.add_row("Total Topics Studied", str(stats["total_topics"])) + table.add_row("Total Notes Created", str(stats["total_notes"])) + table.add_row("Total Files", str(stats["total_files"])) + table.add_row("Study Streak (days)", str(stats["study_streak"])) + table.add_row("Most Productive Topic", stats["most_productive_topic"]) + + console.print(table) + + # Show topics + if stats["topics"]: + console.print(f"\n[bold green]📚 Topics Studied:[/bold green]") + for topic in sorted(stats["topics"]): + console.print(f" ✓ {topic}") + + console.print() + + async def interactive_learning_mode(self): + """Start interactive learning mode""" + console.print("\n[bold magenta]🎓 Interactive Learning Mode[/bold magenta]") + console.print("I'll analyze your current knowledge and suggest what to learn next!\n") + + # First, analyze current knowledge + console.print("[cyan]Analyzing your current knowledge base...[/cyan]") + + # Get analysis + db_data = await asyncio.to_thread(self.vectorstore.get) + all_texts = db_data['documents'] + all_metadatas = db_data['metadatas'] + + # Filter excluded + filtered_pairs = [ + (text, meta) for text, meta in zip(all_texts, all_metadatas) + if meta and not meta.get('exclude', False) + ] + + if not filtered_pairs: + console.print("[yellow]⚠️ No learning materials found. 
+            console.print("[yellow]⚠️ No learning materials found. Add some notes first![/yellow]")
+            return
+
+        # Build context for analysis
+        full_context = ""
+        for text, meta in filtered_pairs[:20]:  # Limit context
+            full_context += f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
+
+        # Get AI analysis
+        chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
+
+        console.print("[cyan]Getting AI analysis of your progress...[/cyan]")
+        analysis_response = ""
+        async for chunk in chain.astream({
+            "context": full_context,
+            "question": "Analyze my learning progress and identify what I've learned well and what gaps exist.",
+            "history": ""
+        }):
+            analysis_response += chunk
+
+        console.print(f"\n[bold green]📈 Your Learning Analysis:[/bold green]")
+        console.print(analysis_response)
+
+        # Get suggestions
+        console.print("\n[cyan]Generating personalized learning suggestions...[/cyan]")
+
+        suggestion_chain = get_chain(SYSTEM_PROMPT_SUGGESTION)
+        suggestion_response = ""
+        async for chunk in suggestion_chain.astream({
+            "context": full_context,
+            "question": "Based on this student's current knowledge, what should they learn next?",
+            "history": ""
+        }):
+            suggestion_response += chunk
+
+        console.print(f"\n[bold blue]💡 Recommended Next Topics:[/bold blue]")
+        console.print(suggestion_response)
+
+        # Save progress
+        progress_cache = load_progress_cache()
+        if "sessions" not in progress_cache:
+            progress_cache["sessions"] = []
+
+        progress_cache["sessions"].append({
+            "date": datetime.now().isoformat(),
+            "type": "analysis",
+            "topics_count": len(filtered_pairs)
+        })
+
+        save_progress_cache(progress_cache)
+
+        console.print(f"\n[green]✓ Analysis complete! Add notes about the suggested topics and run 'learning mode' again.[/green]")
+
+    async def suggest_topics(self):
+        """Suggest new topics to learn"""
+        console.print("\n[bold blue]💡 Topic Suggestions[/bold blue]")
+
+        # Get current knowledge
+        db_data = await asyncio.to_thread(self.vectorstore.get)
+        all_texts = db_data['documents']
+        all_metadatas = db_data['metadatas']
+
+        filtered_pairs = [
+            (text, meta) for text, meta in zip(all_texts, all_metadatas)
+            if meta and not meta.get('exclude', False)
+        ][:15]  # Limit context
+
+        if not filtered_pairs:
+            console.print("[yellow]⚠️ No notes found. Start by creating some learning materials![/yellow]")
+            return
+
+        # Build context
+        context = ""
+        for text, meta in filtered_pairs:
+            context += f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
+
+        # Get suggestions from AI
+        chain = get_chain(SYSTEM_PROMPT_SUGGESTION)
+
+        console.print("[cyan]Analyzing your knowledge and generating suggestions...[/cyan]\n")
+
+        response = ""
+        async for chunk in chain.astream({
+            "context": context,
+            "question": "What are the next logical topics for this student to learn?",
+            "history": ""
+        }):
+            response += chunk
+            console.print(chunk, end="")
+
+        console.print("\n")
+
+    async def exclude_file_interactive(self):
+        """Interactively exclude a file from learning analysis"""
+        console.print("\n[bold yellow]📁 Exclude File from Analysis[/bold yellow]")
+
+        # List all non-excluded files
+        db_data = await asyncio.to_thread(self.vectorstore.get)
+        files = set()
+
+        for meta in db_data['metadatas']:
+            if meta and 'source' in meta and not meta.get('exclude', False):
+                files.add(meta['source'])
+
+        if not files:
+            console.print("[yellow]⚠️ No files found to exclude.[/yellow]")
+            return
+
+        # Show files
+        file_list = sorted(list(files))
+        console.print("\n[bold]Available files:[/bold]")
+        for i, file_path in enumerate(file_list, 1):
+            console.print(f"  {i}. {Path(file_path).name}")
+
+        # Get user choice
+        choice = Prompt.ask("\nSelect file number to exclude",
+                            choices=[str(i) for i in range(1, len(file_list) + 1)],
+                            default="1")
+
+        selected_file = file_list[int(choice) - 1]
+
+        # Confirmation
+        if Confirm.ask(f"\nExclude '{Path(selected_file).name}' from learning analysis?"):
+            # Update the file's metadata in vectorstore
+            try:
+                # Note: In a real implementation, you'd need to update the file's frontmatter
+                # For now, we'll show instructions
+                console.print(f"\n[red]⚠️ Manual action required:[/red]")
+                console.print(f"Add 'exclude: true' to the frontmatter of:")
+                console.print(f"  {selected_file}")
+                console.print(f"\n[dim]Example:[/dim]")
+                console.print("```\n---\nexclude: true\n---\n```")
+                console.print(f"\n[green]The file will be excluded on next reindex.[/green]")
+            except Exception as e:
+                console.print(f"[red]✗ Error: {e}[/red]")
+
+# =========================
+# MAIN APPLICATION
+# =========================
 async def main():
+    """Main application entry point"""
+
+    # Setup directories
     Path(MD_DIRECTORY).mkdir(parents=True, exist_ok=True)
     Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)
 
+    # Display welcome banner
     console.print(Panel.fit(
-        f"[bold cyan]⚡ Dual-Mode RAG System[/bold cyan]\n"
-        f"📂 Docs: {MD_DIRECTORY}\n"
-        f"🧠 Embed: {EMBEDDING_MODEL}\n"
-        f"🤖 LLM: {LLM_MODEL}",
+        f"[bold cyan]⚡ RAG Learning System[/bold cyan]\n"
+        f"📂 Notes Directory: {MD_DIRECTORY}\n"
+        f"🧠 Embedding Model: {EMBEDDING_MODEL}\n"
+        f"🤖 LLM Model: {LLM_MODEL}\n"
+        f"[dim]Commands: /help for available commands[/dim]",
         border_style="cyan"
     ))
 
+    # Initialize components
     embeddings = OllamaEmbeddings(
         model=EMBEDDING_MODEL, base_url=OLLAMA_BASE_URL
     )
+
     vectorstore = Chroma(
         collection_name=COLLECTION_NAME,
         persist_directory=CHROMA_PATH,
@@ -285,9 +791,14 @@ async def main():
     )
 
     processor = ChunkProcessor(vectorstore)
+    analytics = LearningAnalytics(vectorstore)
+    commands = InteractiveCommands(vectorstore, analytics)
+
     cache = load_hash_cache()
 
-    # Checking documents
+    # Index existing documents
+    console.print(f"\n[bold yellow]📚 Indexing documents...[/bold yellow]")
+
     files = [
         os.path.join(root, file)
         for root, _, files in os.walk(MD_DIRECTORY)
@@ -299,49 +810,106 @@ async def main():
         async with semaphore:
             return await processor.index_file(fp, cache)
 
-    tasks = [sem_task(fp) for fp in files]
-    for fut in asyncio.as_completed(tasks):
-        await fut
+    # Use progress bar for indexing
+    with Progress(
+        SpinnerColumn(),
+        TextColumn("[progress.description]{task.description}"),
+        console=console
+    ) as progress:
+        task = progress.add_task("Indexing files...", total=len(files))
+
+        tasks = [sem_task(fp) for fp in files]
+        for fut in asyncio.as_completed(tasks):
+            await fut
+            progress.advance(task)
+
     save_hash_cache(cache)
 
+    # Start file watcher
     observer = start_watcher(processor, cache)
     memory = ConversationMemory()
 
+    # Show help hint
+    console.print(f"\n[dim]💡 Type /help to see available commands[/dim]\n")
+
     try:
         while True:
+            # Get user input
             query = await session.prompt_async("> ", style=style)
             query = query.strip()
 
-            if query.lower() in {"exit", "quit", "q"}:
-                console.print("\nGoodbye!", style="yellow")
-                break
-            if not query: continue
-
+            if not query:
+                continue
+
+            # Handle commands
+            if query.startswith('/'):
+                command = query[1:].lower().strip()
+
+                if command in ['exit', 'quit', 'q']:
+                    console.print("\n👋 Goodbye!", style="yellow")
+                    break
+
+                elif command in ['help', 'h']:
+                    await show_help()
+
+                elif command in ['stats', 'statistics']:
+                    await commands.show_learning_stats()
+
+                elif command in ['excluded', 'list-excluded']:
+                    await commands.list_excluded_files()
+
+                elif command in ['learning-mode', 'learn']:
+                    await commands.interactive_learning_mode()
+
+                elif command in ['suggest', 'suggestions']:
+                    await commands.suggest_topics()
+
+                elif command in ['exclude']:
+                    await commands.exclude_file_interactive()
+
+                elif command in ['reindex']:
+                    console.print("\n[yellow]🔄 Reindexing all files...[/yellow]")
+                    cache.clear()
+                    for file_path in files:
+                        await processor.index_file(file_path, cache)
+                    save_hash_cache(cache)
+                    console.print("[green]✓ Reindexing complete![/green]")
+
+                else:
+                    console.print(f"[red]✗ Unknown command: {command}[/red]")
+                    console.print("[dim]Type /help to see available commands[/dim]")
+
+                continue
+
+            # Process normal queries
             console.print()
-
             mode = classify_intent(query)
             history_str = memory.get_history()
 
             if mode == "SEARCH":
-                console.print("🔍 SEARCH MODE (Top-K)", style="bold blue")
+                console.print("🔍 SEARCH MODE (Top-K Retrieval)", style="bold blue")
                 retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
                 docs = await asyncio.to_thread(retriever.invoke, query)
-                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+                context_str = "\n\n".join(
+                    f"[{Path(d.metadata['source']).name}]\n{d.page_content}"
+                    for d in docs
+                )
                 chain = get_chain(SYSTEM_PROMPT_SEARCH)
 
-            else: # ANALYSIS MODE
-                console.print("📊 ANALYSIS MODE (Full Context)", style="bold magenta")
+            elif mode == "ANALYSIS":
+                console.print("📊 ANALYSIS MODE (Full Context Evaluation)", style="bold magenta")
 
-                db_data = await asyncio.to_thread(vectorstore.get)
+                db_data = await asyncio.to_thread(vectorstore.get)
                 all_texts = db_data['documents']
                 all_metas = db_data['metadatas']
 
                 if not all_texts:
-                    console.print("No documents found to analyze!", style="red")
+                    console.print("[red]No documents found to analyze![/red]")
                     continue
 
-                # Exclude chunks where metadata has exclude: true
+                # Filter excluded chunks
                 filtered_pairs = [
                     (text, meta) for text, meta in zip(all_texts, all_metas)
                     if meta and not meta.get('exclude', False)
@@ -352,15 +920,14 @@ async def main():
                     console.print(f"ℹ Excluded {excluded_count} chunks marked 'exclude: true'", style="dim")
 
                 if not filtered_pairs:
-                    console.print("All documents are marked for exclusion. Nothing to analyze.", style="yellow")
+                    console.print("[yellow]All documents are marked for exclusion. Nothing to analyze.[/yellow]")
                     continue
 
+                # Build context
                 full_context = ""
                 char_count = 0
 
-                paired = sorted(filtered_pairs, key=lambda x: x[1]['source'])
-
-                for text, meta in paired:
+                for text, meta in filtered_pairs[:25]:  # Limit for analysis
                     entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
                     if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
                         full_context += "\n[...Truncated due to context limit...]"
@@ -372,6 +939,19 @@ async def main():
                 context_str = full_context
                 chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
 
+            elif mode == "SUGGEST":
+                await commands.suggest_topics()
+                continue
+
+            elif mode == "STATS":
+                await commands.show_learning_stats()
+                continue
+
+            elif mode == "LEARN":
+                await commands.interactive_learning_mode()
+                continue
+
+            # Generate and display response
             response = ""
             console.print(f"Context size: {len(context_str)} chars", style="dim")
             console.print("Assistant:", style="blue", end=" ")
@@ -385,20 +965,60 @@ async def main():
                 response += chunk
             console.print("\n")
 
+            # Update conversation memory
             memory.add("user", query)
             memory.add("assistant", response)
 
     finally:
+        # Cleanup
        observer.stop()
         observer.join()
 
+async def show_help():
+    """Display help information"""
+    console.print("\n[bold cyan]📖 Available Commands:[/bold cyan]")
+    console.print("=" * 50, style="dim")
+
+    commands = [
+        ("/help", "Show this help message"),
+        ("/stats", "Display learning statistics and progress"),
+        ("/learning-mode", "Start interactive learning analysis"),
+        ("/suggest", "Get topic suggestions for next study"),
+        ("/excluded", "List files excluded from analysis"),
+        ("/exclude", "Interactively exclude a file"),
+        ("/reindex", "Reindex all documents"),
+        ("/exit, /quit, /q", "Exit the application"),
+    ]
+
+    for cmd, desc in commands:
+        console.print(f"[yellow]{cmd:<20}[/yellow] {desc}")
+
+    console.print("\n[bold cyan]🎯 Learning Modes:[/bold cyan]")
+    console.print("=" * 50, style="dim")
+    console.print("• [blue]Search Mode[/blue]: Ask questions about your notes")
+    console.print("• [magenta]Analysis Mode[/magenta]: Get progress evaluation")
+    console.print("• [green]Suggestion Mode[/green]: Get topic recommendations")
+
+    console.print("\n[bold cyan]💡 Examples:[/bold cyan]")
+    console.print("=" * 50, style="dim")
+    console.print("• \"What is SQL JOIN?\" → Search your notes")
+    console.print("• \"Assess my progress\" → Analyze learning")
+    console.print("• \"What should I learn next?\" → Get suggestions")
+    console.print("• \"Show my statistics\" → Display progress")
+
+    console.print()
+
 if __name__ == "__main__":
     import nest_asyncio
     nest_asyncio.apply()
+
     try:
         import asyncio
         loop = asyncio.get_event_loop()
         loop.run_until_complete(main())
     except KeyboardInterrupt:
-        console.print("\nGoodbye!", style="yellow")
+        console.print("\n👋 Goodbye!", style="yellow")
         sys.exit(0)
+    except Exception as e:
+        console.print(f"\n[red]✗ Unexpected error: {e}[/red]")
+        sys.exit(1)
diff --git a/pyproject.toml b/pyproject.toml
index 37231be..7bc095d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,7 @@ dependencies = [
     "nest-asyncio>=1.6.0",
     "prompt-toolkit>=3.0.52",
     "python-dotenv>=1.2.1",
+    "pyyaml>=6.0.3",
     "rich>=14.2.0",
     "unstructured[md]>=0.18.21",
     "watchdog>=6.0.0",
diff --git a/uv.lock b/uv.lock
index 6b71a68..b0b3841 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2100,6 +2100,7 @@ dependencies = [
     { name = "nest-asyncio" },
     { name = "prompt-toolkit" },
     { name = "python-dotenv" },
+    { name = "pyyaml" },
     { name = "rich" },
     { name = "unstructured", extra = ["md"] },
     { name = "watchdog" },
@@ -2114,6 +2115,7 @@ requires-dist = [
     { name = "nest-asyncio", specifier = ">=1.6.0" },
     { name = "prompt-toolkit", specifier = ">=3.0.52" },
     { name = "python-dotenv", specifier = ">=1.2.1" },
+    { name = "pyyaml", specifier = ">=6.0.3" },
     { name = "rich", specifier = ">=14.2.0" },
     { name = "unstructured", extras = ["md"], specifier = ">=0.18.21" },
     { name = "watchdog", specifier = ">=6.0.0" },