diff --git a/main.py b/main.py
index 3c169dc..ce314bf 100644
--- a/main.py
+++ b/main.py
@@ -4,14 +4,16 @@ import sys
 import json
 import hashlib
 import asyncio
+import re
 from pathlib import Path
 from collections import deque
-from typing import List, Dict
+from typing import List, Dict, Tuple
 
 import torch
 from dotenv import load_dotenv
 from rich.console import Console
 from rich.panel import Panel
+from rich.markdown import Markdown
 from prompt_toolkit import PromptSession
 from prompt_toolkit.styles import Style
 from prompt_toolkit.patch_stdout import patch_stdout
@@ -36,23 +38,35 @@ load_dotenv()
 
 style = Style.from_dict({"prompt": "bold #6a0dad"})
 
-SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
+# --- PROMPTS ---
+SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
+SYSTEM_PROMPT_ANALYSIS = (
+    "You are an expert tutor and progress evaluator. "
+    "You have access to the student's entire knowledge base below. "
+    "Analyze the coverage, depth, and connections in the notes. "
+    "Identify what the user has learned well, what is missing, and suggest the next logical steps. "
+    "Do not just summarize; evaluate the progress."
+)
+
 USER_PROMPT_TEMPLATE = os.getenv("USER_PROMPT_TEMPLATE", "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
 
-MD_DIRECTORY = os.getenv("MD_FOLDER")
-EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
-LLM_MODEL = os.getenv("LLM_MODEL")
+MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
+EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
+LLM_MODEL = os.getenv("LLM_MODEL", "llama3")
 CHROMA_PATH = "./.cache/chroma_db"
 HASH_CACHE = "./.cache/file_hashes.json"
 
 MAX_EMBED_CHARS = 380
 CHUNK_SIZE = 1200
 CHUNK_OVERLAP = 200
 TOP_K = 6
 COLLECTION_NAME = "md_rag"
+# Limit context size for Analysis mode (approx 24k chars ~ 6k tokens) to prevent OOM
+MAX_ANALYSIS_CONTEXT_CHARS = 24000
+
 BATCH_SIZE = 10
 MAX_PARALLEL_FILES = 3
@@ -62,25 +76,18 @@ MAX_PARALLEL_FILES = 3
 def setup_gpu():
     if torch.cuda.is_available():
         torch.cuda.set_per_process_memory_fraction(0.95)
-
         device_id = torch.cuda.current_device()
         device_name = torch.cuda.get_device_name(device_id)
-
-        # VRAM info (in GB)
         total_vram = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
-        allocated = torch.cuda.memory_allocated(device_id) / (1024**3)
-        reserved = torch.cuda.memory_reserved(device_id) / (1024**3)
-        free = total_vram - reserved
-
-        console.print(f"[green]✓ GPU: {device_name}[/green]")
-        console.print(f"[blue] VRAM: {total_vram:.1f}GB total | {free:.1f}GB free | {allocated:.1f}GB allocated[/blue]")
+        console.print(f"[green]✓ GPU: {device_name} ({total_vram:.1f}GB)[/green]\n")
     else:
-        console.print("[yellow]⚠ CPU mode[/yellow]")
+        console.print("[yellow]⚠ CPU mode[/yellow]\n")
+        console.print("\n")
 
 setup_gpu()
 
 # =========================
-# HASH CACHE
+# UTILS & CACHE
 # =========================
 def get_file_hash(file_path: str) -> str:
     return hashlib.md5(Path(file_path).read_bytes()).hexdigest()
@@ -95,7 +102,27 @@ def save_hash_cache(cache: dict):
     Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))
 
 # =========================
-# CHUNK VALIDATION
+# ROUTING LOGIC
 # =========================
+def classify_intent(query: str) -> str:
+    """
+    Determines if the user wants a specific search (RAG) or a global assessment.
+    """
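+    # Keyword heuristics below include both English and Russian trigger phrases;
+    # the query is lowercased before matching, so the patterns are effectively case-insensitive.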
+ """ + analysis_keywords = [ + r"assess my progress", r"eval(uate)? my (learning|knowledge)", + r"what have i learned", r"summary of (my )?notes", + r"my progress", r"learning path", r"knowledge gap", + r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний" + ] + + query_lower = query.lower() + for pattern in analysis_keywords: + if re.search(pattern, query_lower): + return "ANALYSIS" + return "SEARCH" + +# ========================= +# DOCUMENT PROCESSING # ========================= def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]: if len(text) <= max_chars: @@ -109,8 +136,8 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str if len(current) + len(sentence) <= max_chars: current += sentence else: - if current: - chunks.append(current.strip()) + if current: chunks.append(current.strip()) + # Handle extremely long sentences by word splitting if len(sentence) > max_chars: words = sentence.split() temp = "" @@ -118,22 +145,16 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str if len(temp) + len(word) + 1 <= max_chars: temp += word + " " else: - if temp: - chunks.append(temp.strip()) + if temp: chunks.append(temp.strip()) temp = word + " " - if temp: - chunks.append(temp.strip()) + if temp: chunks.append(temp.strip()) + current = "" else: current = sentence - if current: - chunks.append(current.strip()) - + if current: chunks.append(current.strip()) return [c for c in chunks if c] -# ========================= -# DOCUMENT PROCESSING -# ========================= class ChunkProcessor: def __init__(self, vectorstore): self.vectorstore = vectorstore @@ -141,9 +162,7 @@ class ChunkProcessor: async def process_file(self, file_path: str) -> List[Dict]: try: - docs = await asyncio.to_thread( - UnstructuredMarkdownLoader(file_path).load - ) + docs = await asyncio.to_thread(UnstructuredMarkdownLoader(file_path).load) except Exception as e: console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]") return [] @@ -167,40 +186,15 @@ class ChunkProcessor: return chunks async def embed_batch(self, batch: List[Dict]) -> bool: - if not batch: - return True - + if not batch: return True try: - docs = [Document(page_content=c["text"], metadata=c["metadata"]) - for c in batch] + docs = [Document(page_content=c["text"], metadata=c["metadata"]) for c in batch] ids = [c["id"] for c in batch] - - await asyncio.to_thread( - self.vectorstore.add_documents, - docs, - ids=ids - ) + await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids) return True - except Exception as e: - error_msg = str(e).lower() - if "context length" in error_msg or "input length" in error_msg: - console.print(f"[yellow]⚠ Oversized chunk detected, processing individually[/yellow]") - for item in batch: - try: - doc = Document(page_content=item["text"], metadata=item["metadata"]) - await asyncio.to_thread( - self.vectorstore.add_documents, - [doc], - ids=[item["id"]] - ) - except Exception: - console.print(f"[red]✗ Skipping chunk (too large): {len(item['text'])} chars[/red]") - continue - return True - else: - console.print(f"[red]✗ Embed error: {e}[/red]") - return False + console.print(f"[red]✗ Embed error: {e}[/red]") + return False async def index_file(self, file_path: str, cache: dict) -> bool: async with self.semaphore: @@ -209,14 +203,16 @@ class ChunkProcessor: return False chunks = await self.process_file(file_path) - if not chunks: - return False + if not chunks: return False + + try: + 
+            try:
+                self.vectorstore._collection.delete(where={"source": file_path})
+            except Exception:
+                pass # Collection might be empty
 
             for i in range(0, len(chunks), BATCH_SIZE):
                 batch = chunks[i:i + BATCH_SIZE]
-                success = await self.embed_batch(batch)
-                if not success:
-                    console.print(f"[yellow]⚠ Partial failure in {Path(file_path).name}[/yellow]")
+                await self.embed_batch(batch)
 
             cache[file_path] = current_hash
             console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
@@ -256,7 +252,7 @@ def start_watcher(processor, cache):
     return observer
 
 # =========================
-# RAG CHAIN & MEMORY
+# RAG CHAIN FACTORY
 # =========================
 class ConversationMemory:
     def __init__(self, max_messages: int = 8):
@@ -269,18 +265,15 @@ class ConversationMemory:
             self.messages.pop(0)
 
     def get_history(self) -> str:
-        if not self.messages:
-            return "No previous conversation."
+        if not self.messages: return "No previous conversation."
         return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])
 
-def get_rag_components(retriever):
-    llm = ChatOllama(model=LLM_MODEL, temperature=0.1)
-
+def get_chain(system_prompt):
+    llm = ChatOllama(model=LLM_MODEL, temperature=0.2)
     prompt = ChatPromptTemplate.from_messages([
-        ("system", SYSTEM_PROMPT),
+        ("system", system_prompt),
         ("human", USER_PROMPT_TEMPLATE)
     ])
-
     return prompt | llm | StrOutputParser()
 
 # =========================
@@ -291,7 +284,7 @@ async def main():
     Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)
 
     console.print(Panel.fit(
-        f"[bold cyan]⚡ RAG System[/bold cyan]\n"
+        f"[bold cyan]⚡ Dual-Mode RAG System[/bold cyan]\n"
        f"📂 Docs: {MD_DIRECTORY}\n"
        f"🧠 Embed: {EMBEDDING_MODEL}\n"
        f"🤖 LLM: {LLM_MODEL}",
@@ -308,13 +301,14 @@ async def main():
     processor = ChunkProcessor(vectorstore)
     cache = load_hash_cache()
 
-    console.print("\n[yellow]Indexing documents...[/yellow]")
+    console.print("\n[yellow]Checking documents...[/yellow]")
     files = [
         os.path.join(root, file)
         for root, _, files in os.walk(MD_DIRECTORY)
         for file in files if file.endswith(".md")
     ]
 
+    # Initial Indexing
     semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
     async def sem_task(fp):
         async with semaphore:
@@ -325,19 +319,10 @@ async def main():
         await fut
 
     save_hash_cache(cache)
-    console.print(f"[green]✓ Processed {len(files)} files[/green]\n")
-
     observer = start_watcher(processor, cache)
-
-    retriever = vectorstore.as_retriever(
-        search_type="similarity",
-        search_kwargs={"k": TOP_K}
-    )
-
-    rag_chain = get_rag_components(retriever)
     memory = ConversationMemory()
 
-    console.print("[bold green]💬 Ready![/bold green]\n")
+    console.print("[bold green]💬 Ready! Type 'exit' to quit.[/bold green]\n")
 
     try:
         with patch_stdout():
@@ -347,15 +332,57 @@ async def main():
                 if query.lower() in {"exit", "quit", "q"}:
                     print("Goodbye!")
                     break
-                if not query:
-                    continue
+                if not query: continue
 
-                docs = await asyncio.to_thread(retriever.invoke, query)
-                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+                mode = classify_intent(query)
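+                # Route the query: ANALYSIS reviews the whole note base, SEARCH runs standard top-k retrieval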
                 history_str = memory.get_history()
 
+                if mode == "SEARCH":
+                    console.print("[bold blue]🔍 SEARCH MODE (Top-K)[/bold blue]")
+
+                    # Standard RAG
+                    retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
+                    docs = await asyncio.to_thread(retriever.invoke, query)
+                    context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+
+                    chain = get_chain(SYSTEM_PROMPT_SEARCH)
+
+                else: # ANALYSIS MODE
+                    console.print("[bold magenta]📊 ANALYSIS MODE (Full Context)[/bold magenta]")
+
+                    # Fetch ALL documents (limited by size)
+                    # Chroma .get() returns dict with keys: ids, embeddings, documents, metadatas
+                    db_data = await asyncio.to_thread(vectorstore.get)
+                    all_texts = db_data['documents']
+                    all_metas = db_data['metadatas']
+
+                    if not all_texts:
+                        console.print("[red]No documents found to analyze![/red]")
+                        continue
+
+                    # Concatenate content for analysis
+                    full_context = ""
+                    char_count = 0
+
+                    # Sort by source so chunks from the same file are grouped together
+                    paired = sorted(zip(all_texts, all_metas), key=lambda x: x[1]['source'])
+
+                    for text, meta in paired:
+                        entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
+                        if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
+                            full_context += "\n[...Truncated due to context limit...]"
+                            console.print("[yellow]⚠ Context limit reached, truncating analysis data.[/yellow]")
+                            break
+                        full_context += entry
+                        char_count += len(entry)
+
+                    context_str = full_context
+                    chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
+
                 response = ""
-                async for chunk in rag_chain.astream({
+                console.print(f"[dim]Context size: {len(context_str)} chars[/dim]")
+
+                async for chunk in chain.astream({
                     "context": context_str,
                     "question": query,
                     "history": history_str