#!/usr/bin/env python3
import os
import sys
import json
import hashlib
import asyncio
from pathlib import Path
from collections import deque
from typing import List, Dict

import torch
from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from prompt_toolkit import PromptSession
from prompt_toolkit.styles import Style
from prompt_toolkit.patch_stdout import patch_stdout
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

# =========================
# CONFIG
# =========================
console = Console()
session = PromptSession()
load_dotenv()

style = Style.from_dict({"prompt": "bold #6a0dad"})

SYSTEM_PROMPT = os.getenv(
    "SYSTEM_PROMPT",
    "You are a precise technical assistant. Cite sources using [filename]. Be concise.",
)
USER_PROMPT_TEMPLATE = os.getenv(
    "USER_PROMPT_TEMPLATE",
    "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}",
)

MD_DIRECTORY = os.getenv("MD_FOLDER")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
LLM_MODEL = os.getenv("LLM_MODEL")

CHROMA_PATH = "./.cache/chroma_db"
HASH_CACHE = "./.cache/file_hashes.json"
MAX_EMBED_CHARS = 380
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 200
TOP_K = 6
COLLECTION_NAME = "md_rag"
BATCH_SIZE = 10
MAX_PARALLEL_FILES = 3
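
# -----------------------------------------------------------------
# Example .env (illustration only; the values below are assumptions
# and not part of the original script). Point the model names at
# models already pulled into your local Ollama install:
#
#   MD_FOLDER=./docs
#   EMBEDDING_MODEL=nomic-embed-text
#   LLM_MODEL=llama3.1
#   # SYSTEM_PROMPT / USER_PROMPT_TEMPLATE are optional overrides
#
# Rough dependency set (package names assumed, adjust as needed):
#   pip install torch python-dotenv rich prompt-toolkit watchdog nest-asyncio \
#       langchain-community langchain-text-splitters langchain-ollama \
#       langchain-chroma unstructured
# -----------------------------------------------------------------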

# =========================
# GPU SETUP
# =========================
def setup_gpu():
    """Report GPU/VRAM status; the script falls back to CPU if CUDA is unavailable."""
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.95)
        device_id = torch.cuda.current_device()
        device_name = torch.cuda.get_device_name(device_id)
        # VRAM info (in GB)
        total_vram = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
        allocated = torch.cuda.memory_allocated(device_id) / (1024**3)
        reserved = torch.cuda.memory_reserved(device_id) / (1024**3)
        free = total_vram - reserved
        console.print(f"[green]✓ GPU: {device_name}[/green]")
        console.print(
            f"[blue] VRAM: {total_vram:.1f}GB total | {free:.1f}GB free | {allocated:.1f}GB allocated[/blue]"
        )
    else:
        console.print("[yellow]⚠ CPU mode[/yellow]")


setup_gpu()


# =========================
# HASH CACHE
# =========================
def get_file_hash(file_path: str) -> str:
    """MD5 of the file contents, used only for change detection."""
    return hashlib.md5(Path(file_path).read_bytes()).hexdigest()


def load_hash_cache() -> dict:
    Path(HASH_CACHE).parent.mkdir(parents=True, exist_ok=True)
    if Path(HASH_CACHE).exists():
        return json.loads(Path(HASH_CACHE).read_text())
    return {}


def save_hash_cache(cache: dict):
    Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))


# =========================
# CHUNK VALIDATION
# =========================
def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
    """Split a chunk that exceeds the embedding character budget, first on
    sentence boundaries and, for oversized sentences, on word boundaries."""
    if len(text) <= max_chars:
        return [text]

    sentences = text.replace('. ', '.|').replace('! ', '!|').replace('? ', '?|').split('|')
    chunks = []
    current = ""

    for sentence in sentences:
        if len(current) + len(sentence) <= max_chars:
            current += sentence
        else:
            if current:
                chunks.append(current.strip())
                current = ""
            if len(sentence) > max_chars:
                # Sentence alone exceeds the budget: hard-split on word boundaries.
                words = sentence.split()
                temp = ""
                for word in words:
                    if len(temp) + len(word) + 1 <= max_chars:
                        temp += word + " "
                    else:
                        if temp:
                            chunks.append(temp.strip())
                        temp = word + " "
                if temp:
                    chunks.append(temp.strip())
            else:
                current = sentence

    if current:
        chunks.append(current.strip())

    return [c for c in chunks if c]


# =========================
# DOCUMENT PROCESSING
# =========================
class ChunkProcessor:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
        self.semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def process_file(self, file_path: str) -> List[Dict]:
        """Load a markdown file and split it into embedding-sized chunks."""
        try:
            docs = await asyncio.to_thread(
                UnstructuredMarkdownLoader(file_path).load
            )
        except Exception as e:
            console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]")
            return []

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", " "]
        )

        chunks = []
        for doc_idx, doc in enumerate(docs):
            for chunk_idx, text in enumerate(splitter.split_text(doc.page_content)):
                safe_texts = validate_chunk_size(text)
                for sub_idx, safe_text in enumerate(safe_texts):
                    chunks.append({
                        "id": f"{file_path}::{doc_idx}::{chunk_idx}::{sub_idx}",
                        "text": safe_text,
                        "metadata": {"source": file_path, **doc.metadata}
                    })
        return chunks

    async def embed_batch(self, batch: List[Dict]) -> bool:
        if not batch:
            return True
        try:
            docs = [Document(page_content=c["text"], metadata=c["metadata"]) for c in batch]
            ids = [c["id"] for c in batch]
            await asyncio.to_thread(
                self.vectorstore.add_documents, docs, ids=ids
            )
            return True
        except Exception as e:
            error_msg = str(e).lower()
            if "context length" in error_msg or "input length" in error_msg:
                # Retry one document at a time so a single oversized chunk
                # does not sink the whole batch.
                console.print("[yellow]⚠ Oversized chunk detected, processing individually[/yellow]")
                for item in batch:
                    try:
                        doc = Document(page_content=item["text"], metadata=item["metadata"])
                        await asyncio.to_thread(
                            self.vectorstore.add_documents, [doc], ids=[item["id"]]
                        )
                    except Exception:
                        console.print(f"[red]✗ Skipping chunk (too large): {len(item['text'])} chars[/red]")
                        continue
                return True
            else:
                console.print(f"[red]✗ Embed error: {e}[/red]")
                return False

    async def index_file(self, file_path: str, cache: dict) -> bool:
        """Re-index a file unless its content hash is unchanged."""
        async with self.semaphore:
            current_hash = get_file_hash(file_path)
            if cache.get(file_path) == current_hash:
                return False

            chunks = await self.process_file(file_path)
            if not chunks:
                return False

            for i in range(0, len(chunks), BATCH_SIZE):
                batch = chunks[i:i + BATCH_SIZE]
                success = await self.embed_batch(batch)
                if not success:
                    console.print(f"[yellow]⚠ Partial failure in {Path(file_path).name}[/yellow]")

            cache[file_path] = current_hash
            console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
            return True


# =========================
# FILE WATCHER
# =========================
class DocumentWatcher(FileSystemEventHandler):
    """Queues modified .md files; process_queue drains the queue one file at a time."""

    def __init__(self, processor, cache):
        self.processor = processor
        self.cache = cache
        self.queue = deque()
        self.processing = False

    def on_modified(self, event):
        if not event.is_directory and event.src_path.endswith(".md"):
            self.queue.append(event.src_path)

    async def process_queue(self):
        while True:
            if self.queue and not self.processing:
                self.processing = True
                file_path = self.queue.popleft()
                if Path(file_path).exists():
                    await self.processor.index_file(file_path, self.cache)
                    save_hash_cache(self.cache)
                self.processing = False
            await asyncio.sleep(1)

def start_watcher(processor, cache):
    handler = DocumentWatcher(processor, cache)
    observer = Observer()
    observer.schedule(handler, MD_DIRECTORY, recursive=True)
    observer.start()
    asyncio.create_task(handler.process_queue())
    return observer


# =========================
# RAG CHAIN & MEMORY
# =========================
class ConversationMemory:
    def __init__(self, max_messages: int = 8):
        self.messages = []
        self.max_messages = max_messages

    def add(self, role: str, content: str):
        self.messages.append({"role": role, "content": content})
        if len(self.messages) > self.max_messages:
            self.messages.pop(0)

    def get_history(self) -> str:
        if not self.messages:
            return "No previous conversation."
        return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])


def get_rag_components():
    llm = ChatOllama(model=LLM_MODEL, temperature=0.1)
    prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        ("human", USER_PROMPT_TEMPLATE)
    ])
    return prompt | llm | StrOutputParser()


# =========================
# MAIN
# =========================
async def main():
    # Fail fast if required environment variables are missing.
    if not MD_DIRECTORY or not EMBEDDING_MODEL or not LLM_MODEL:
        console.print("[red]✗ MD_FOLDER, EMBEDDING_MODEL and LLM_MODEL must be set (e.g. in a .env file)[/red]")
        sys.exit(1)

    Path(MD_DIRECTORY).mkdir(parents=True, exist_ok=True)
    Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)

    console.print(Panel.fit(
        f"[bold cyan]⚡ RAG System[/bold cyan]\n"
        f"📂 Docs: {MD_DIRECTORY}\n"
        f"🧠 Embed: {EMBEDDING_MODEL}\n"
        f"🤖 LLM: {LLM_MODEL}",
        border_style="cyan"
    ))

    embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
    vectorstore = Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory=CHROMA_PATH,
        embedding_function=embeddings
    )

    processor = ChunkProcessor(vectorstore)
    cache = load_hash_cache()

    console.print("\n[yellow]Indexing documents...[/yellow]")
    files = [
        os.path.join(root, file)
        for root, _, files in os.walk(MD_DIRECTORY)
        for file in files if file.endswith(".md")
    ]

    # ChunkProcessor.index_file already bounds concurrency with its own semaphore.
    tasks = [processor.index_file(fp, cache) for fp in files]
    for fut in asyncio.as_completed(tasks):
        await fut

    save_hash_cache(cache)
    console.print(f"[green]✓ Processed {len(files)} files[/green]\n")

    observer = start_watcher(processor, cache)

    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": TOP_K}
    )
    rag_chain = get_rag_components()
    memory = ConversationMemory()

    console.print("[bold green]💬 Ready![/bold green]\n")

    try:
        with patch_stdout():
            while True:
                query = await session.prompt_async("> ", style=style)
                query = query.strip()
                if query.lower() in {"exit", "quit", "q"}:
                    print("Goodbye!")
                    break
                if not query:
                    continue

                docs = await asyncio.to_thread(retriever.invoke, query)
                context_str = "\n\n".join(
                    f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs
                )
                history_str = memory.get_history()

                response = ""
                async for chunk in rag_chain.astream({
                    "context": context_str,
                    "question": query,
                    "history": history_str
                }):
                    print(chunk, end="", flush=True)
                    response += chunk
                console.print("\n")

                memory.add("user", query)
                memory.add("assistant", response)
    finally:
        observer.stop()
        observer.join()


if __name__ == "__main__":
    import nest_asyncio

    nest_asyncio.apply()
    try:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        console.print("\n[yellow]Goodbye![/yellow]")
        sys.exit(0)