#!/usr/bin/env python3
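# rag-llm/main.py
#
# Local RAG chat over a folder of Markdown files: documents are chunked,
# embedded with an Ollama embedding model, stored in a persistent Chroma
# collection, and served through a streaming terminal chat loop with a short
# rolling conversation memory. A watchdog observer re-indexes files as they
# change. Assumes a running Ollama server (the langchain_ollama classes talk
# to a local Ollama instance by default) and a .env file; see the example
# below the CONFIG block.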
import os
import sys
import json
import hashlib
import asyncio
from pathlib import Path
from collections import deque
from typing import List, Dict
import torch
from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from prompt_toolkit import PromptSession
from prompt_toolkit.styles import Style
from prompt_toolkit.patch_stdout import patch_stdout
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
# =========================
# CONFIG
# =========================
console = Console()
session = PromptSession()
load_dotenv()
style = Style.from_dict({"prompt": "bold #6a0dad"})
MD_DIRECTORY = os.getenv("MD_FOLDER")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
LLM_MODEL = os.getenv("LLM_MODEL")
CHROMA_PATH = "./.cache/chroma_db"
HASH_CACHE = "./.cache/file_hashes.json"
MAX_EMBED_CHARS = 380
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 200
TOP_K = 6
COLLECTION_NAME = "md_rag"
BATCH_SIZE = 10
MAX_PARALLEL_FILES = 3
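# Example .env for this script. The variable names come from the os.getenv()
# calls above; the model names are illustrative assumptions, not taken from
# the source -- substitute whatever models you have pulled into Ollama:
#
#   MD_FOLDER=./docs
#   EMBEDDING_MODEL=nomic-embed-text
#   LLM_MODEL=llama3.1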
# =========================
# GPU SETUP
# =========================
def setup_gpu():
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.95)
    else:
        console.print("[yellow]⚠ CPU mode[/yellow]")

setup_gpu()
# =========================
# HASH CACHE
# =========================
def get_file_hash(file_path: str) -> str:
    return hashlib.md5(Path(file_path).read_bytes()).hexdigest()

def load_hash_cache() -> dict:
    Path(HASH_CACHE).parent.mkdir(parents=True, exist_ok=True)
    if Path(HASH_CACHE).exists():
        return json.loads(Path(HASH_CACHE).read_text())
    return {}

def save_hash_cache(cache: dict):
    Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))
# =========================
# CHUNK VALIDATION
# =========================
def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
    """Split text that exceeds the embedding model's safe input size into <= max_chars pieces."""
    if len(text) <= max_chars:
        return [text]
    # Prefer sentence boundaries first.
    sentences = text.replace('. ', '.|').replace('! ', '!|').replace('? ', '?|').split('|')
    chunks = []
    current = ""
    for sentence in sentences:
        if len(current) + len(sentence) + 1 <= max_chars:
            current += sentence + " "
        else:
            if current:
                chunks.append(current.strip())
                current = ""
            if len(sentence) > max_chars:
                # A single sentence is still too long: fall back to word-level splitting.
                words = sentence.split()
                temp = ""
                for word in words:
                    if len(temp) + len(word) + 1 <= max_chars:
                        temp += word + " "
                    else:
                        if temp:
                            chunks.append(temp.strip())
                        temp = word + " "
                if temp:
                    chunks.append(temp.strip())
            else:
                current = sentence + " "
    if current:
        chunks.append(current.strip())
    return [c for c in chunks if c]
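# Example (illustrative): a 1,000-character paragraph comes back as roughly
# three pieces, each at most MAX_EMBED_CHARS characters, split on sentence
# boundaries first and on word boundaries only when a sentence is still too long.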
# =========================
# DOCUMENT PROCESSING
# =========================
class ChunkProcessor:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
        self.semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def process_file(self, file_path: str) -> List[Dict]:
        try:
            docs = await asyncio.to_thread(
                UnstructuredMarkdownLoader(file_path).load
            )
        except Exception as e:
            console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]")
            return []
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", " "]
        )
        chunks = []
        for doc_idx, doc in enumerate(docs):
            for chunk_idx, text in enumerate(splitter.split_text(doc.page_content)):
                safe_texts = validate_chunk_size(text)
                for sub_idx, safe_text in enumerate(safe_texts):
                    chunks.append({
                        "id": f"{file_path}::{doc_idx}::{chunk_idx}::{sub_idx}",
                        "text": safe_text,
                        "metadata": {"source": file_path, **doc.metadata}
                    })
        return chunks
    async def embed_batch(self, batch: List[Dict]) -> bool:
        if not batch:
            return True
        try:
            docs = [Document(page_content=c["text"], metadata=c["metadata"])
                    for c in batch]
            ids = [c["id"] for c in batch]
            await asyncio.to_thread(
                self.vectorstore.add_documents,
                docs,
                ids=ids
            )
            return True
        except Exception as e:
            error_msg = str(e).lower()
            if "context length" in error_msg or "input length" in error_msg:
                console.print("[yellow]⚠ Oversized chunk detected, processing individually[/yellow]")
                for item in batch:
                    try:
                        doc = Document(page_content=item["text"], metadata=item["metadata"])
                        await asyncio.to_thread(
                            self.vectorstore.add_documents,
                            [doc],
                            ids=[item["id"]]
                        )
                    except Exception:
                        console.print(f"[red]✗ Skipping chunk (too large): {len(item['text'])} chars[/red]")
                        continue
                return True
            else:
                console.print(f"[red]✗ Embed error: {e}[/red]")
                return False
    async def index_file(self, file_path: str, cache: dict) -> bool:
        async with self.semaphore:
            current_hash = get_file_hash(file_path)
            if cache.get(file_path) == current_hash:
                return False
            chunks = await self.process_file(file_path)
            if not chunks:
                return False
            for i in range(0, len(chunks), BATCH_SIZE):
                batch = chunks[i:i + BATCH_SIZE]
                success = await self.embed_batch(batch)
                if not success:
                    console.print(f"[yellow]⚠ Partial failure in {Path(file_path).name}[/yellow]")
            cache[file_path] = current_hash
            console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
            return True
# =========================
# FILE WATCHER
# =========================
class DocumentWatcher(FileSystemEventHandler):
    def __init__(self, processor, cache):
        self.processor = processor
        self.cache = cache
        self.queue = deque()
        self.processing = False

    def on_modified(self, event):
        if not event.is_directory and event.src_path.endswith(".md"):
            self.queue.append(event.src_path)

    async def process_queue(self):
        while True:
            if self.queue and not self.processing:
                self.processing = True
                file_path = self.queue.popleft()
                if Path(file_path).exists():
                    await self.processor.index_file(file_path, self.cache)
                    save_hash_cache(self.cache)
                self.processing = False
            await asyncio.sleep(1)

def start_watcher(processor, cache):
    handler = DocumentWatcher(processor, cache)
    observer = Observer()
    observer.schedule(handler, MD_DIRECTORY, recursive=True)
    observer.start()
    asyncio.create_task(handler.process_queue())
    return observer
# =========================
# RAG CHAIN & MEMORY
# =========================
class ConversationMemory:
    def __init__(self, max_messages: int = 8):
        self.messages = []
        self.max_messages = max_messages

    def add(self, role: str, content: str):
        self.messages.append({"role": role, "content": content})
        if len(self.messages) > self.max_messages:
            self.messages.pop(0)

    def get_history(self) -> str:
        if not self.messages:
            return "No previous conversation."
        return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])

def get_rag_components(retriever):
    # Retrieval happens explicitly in the main loop; the retriever argument is not used here.
    llm = ChatOllama(model=LLM_MODEL, temperature=0.1)
    # FIX 1: Added {history} to the prompt
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a precise technical assistant. Cite sources using [filename]. Be concise."),
        ("human", "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
    ])
    return prompt | llm | StrOutputParser()
# =========================
# MAIN
# =========================
async def main():
    Path(MD_DIRECTORY).mkdir(parents=True, exist_ok=True)
    Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)
    console.print(Panel.fit(
        f"[bold cyan]⚡ RAG System[/bold cyan]\n"
        f"📂 Docs: {MD_DIRECTORY}\n"
        f"🧠 Embed: {EMBEDDING_MODEL}\n"
        f"🤖 LLM: {LLM_MODEL}",
        border_style="cyan"
    ))
    embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
    vectorstore = Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory=CHROMA_PATH,
        embedding_function=embeddings
    )
    processor = ChunkProcessor(vectorstore)
    cache = load_hash_cache()
    console.print("\n[yellow]Indexing documents...[/yellow]")
    files = [
        os.path.join(root, file)
        for root, _, files in os.walk(MD_DIRECTORY)
        for file in files if file.endswith(".md")
    ]
    semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def sem_task(fp):
        async with semaphore:
            return await processor.index_file(fp, cache)

    tasks = [sem_task(fp) for fp in files]
    for fut in asyncio.as_completed(tasks):
        await fut
    save_hash_cache(cache)
    console.print(f"[green]✓ Processed {len(files)} files[/green]\n")
    observer = start_watcher(processor, cache)
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": TOP_K}
    )
    rag_chain = get_rag_components(retriever)
    memory = ConversationMemory()
    console.print("[bold green]💬 Ready![/bold green]\n")
    try:
        with patch_stdout():
            while True:
                query = await session.prompt_async("> ", style=style)
                query = query.strip()
                if query.lower() in {"exit", "quit", "q"}:
                    print("Goodbye!")
                    break
                if not query:
                    continue
                docs = await asyncio.to_thread(retriever.invoke, query)
                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
                history_str = memory.get_history()
                response = ""
                async for chunk in rag_chain.astream({
                    "context": context_str,
                    "question": query,
                    "history": history_str
                }):
                    print(chunk, end="", flush=True)
                    response += chunk
                console.print("\n")
                memory.add("user", query)
                memory.add("assistant", response)
    finally:
        observer.stop()
        observer.join()
if __name__ == "__main__":
    import nest_asyncio
    nest_asyncio.apply()
    try:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        console.print("\n[yellow]Goodbye![/yellow]")
        sys.exit(0)