feat: dual-mode RAG system
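
Route each query into one of two modes instead of always running top-k retrieval:

- classify_intent() matches progress/assessment phrases (English and Russian regex patterns) and returns "ANALYSIS"; everything else stays "SEARCH".
- SEARCH mode keeps the existing behaviour: similarity retrieval with k=TOP_K and the citation-focused system prompt.
- ANALYSIS mode pulls every stored chunk from the Chroma collection via vectorstore.get(), groups chunks by source file, caps the prompt at MAX_ANALYSIS_CONTEXT_CHARS (~24k chars) to avoid OOM, and uses a tutor/progress-evaluator system prompt.
- The chain is now built per query with get_chain(system_prompt); get_rag_components() is removed.
- Also: default values for the MD_FOLDER / EMBEDDING_MODEL / LLM_MODEL env vars, stale chunks are deleted from the collection before a file is re-indexed, and embed_batch() error handling is simplified.

Illustrative routing (a sketch, assuming classify_intent is imported from the main.py changed below):

    from main import classify_intent
    assert classify_intent("How does chunking work?") == "SEARCH"   # no analysis keyword matches
    assert classify_intent("Assess my progress") == "ANALYSIS"      # matches r"assess my progress"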

commit 19d4fe09b6
parent 0e4e438cbc
Date: 2025-12-30 03:24:40 +03:00

main.py (213 lines changed)
@@ -4,14 +4,16 @@ import sys
 import json
 import hashlib
 import asyncio
+import re
 from pathlib import Path
 from collections import deque
-from typing import List, Dict
+from typing import List, Dict, Tuple
 import torch
 from dotenv import load_dotenv
 from rich.console import Console
 from rich.panel import Panel
+from rich.markdown import Markdown
 from prompt_toolkit import PromptSession
 from prompt_toolkit.styles import Style
 from prompt_toolkit.patch_stdout import patch_stdout
@@ -36,23 +38,35 @@ load_dotenv()
 style = Style.from_dict({"prompt": "bold #6a0dad"})

-SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
+# --- PROMPTS ---
+SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
+SYSTEM_PROMPT_ANALYSIS = (
+    "You are an expert tutor and progress evaluator. "
+    "You have access to the student's entire knowledge base below. "
+    "Analyze the coverage, depth, and connections in the notes. "
+    "Identify what the user has learned well, what is missing, and suggest the next logical steps. "
+    "Do not just summarize; evaluate the progress."
+)
 USER_PROMPT_TEMPLATE = os.getenv("USER_PROMPT_TEMPLATE",
     "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")

-MD_DIRECTORY = os.getenv("MD_FOLDER")
-EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
-LLM_MODEL = os.getenv("LLM_MODEL")
+MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
+EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
+LLM_MODEL = os.getenv("LLM_MODEL", "llama3")
 CHROMA_PATH = "./.cache/chroma_db"
 HASH_CACHE = "./.cache/file_hashes.json"
 MAX_EMBED_CHARS = 380
 CHUNK_SIZE = 1200
 CHUNK_OVERLAP = 200
 TOP_K = 6
 COLLECTION_NAME = "md_rag"
+# Limit context size for Analysis mode (approx 24k chars ~ 6k tokens) to prevent OOM
+MAX_ANALYSIS_CONTEXT_CHARS = 24000
 BATCH_SIZE = 10
 MAX_PARALLEL_FILES = 3
@@ -62,25 +76,18 @@ MAX_PARALLEL_FILES = 3
 def setup_gpu():
     if torch.cuda.is_available():
         torch.cuda.set_per_process_memory_fraction(0.95)
         device_id = torch.cuda.current_device()
         device_name = torch.cuda.get_device_name(device_id)
-        # VRAM info (in GB)
         total_vram = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
-        allocated = torch.cuda.memory_allocated(device_id) / (1024**3)
-        reserved = torch.cuda.memory_reserved(device_id) / (1024**3)
-        free = total_vram - reserved
-        console.print(f"[green]✓ GPU: {device_name}[/green]")
-        console.print(f"[blue] VRAM: {total_vram:.1f}GB total | {free:.1f}GB free | {allocated:.1f}GB allocated[/blue]")
+        console.print(f"[green]✓ GPU: {device_name} ({total_vram:.1f}GB)[/green]\n")
     else:
-        console.print("[yellow]⚠ CPU mode[/yellow]")
-    console.print("\n")
+        console.print("[yellow]⚠ CPU mode[/yellow]\n")

 setup_gpu()

 # =========================
-# HASH CACHE
+# UTILS & CACHE
 # =========================
 def get_file_hash(file_path: str) -> str:
     return hashlib.md5(Path(file_path).read_bytes()).hexdigest()
@@ -95,7 +102,27 @@ def save_hash_cache(cache: dict):
     Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))

 # =========================
-# CHUNK VALIDATION
+# ROUTING LOGIC
+# =========================
+def classify_intent(query: str) -> str:
+    """
+    Determines if the user wants a specific search (RAG) or a global assessment.
+    """
+    analysis_keywords = [
+        r"assess my progress", r"eval(uate)? my (learning|knowledge)",
+        r"what have i learned", r"summary of (my )?notes",
+        r"my progress", r"learning path", r"knowledge gap",
+        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний"
+    ]
+    query_lower = query.lower()
+    for pattern in analysis_keywords:
+        if re.search(pattern, query_lower):
+            return "ANALYSIS"
+    return "SEARCH"
+
+# =========================
+# DOCUMENT PROCESSING
 # =========================
 def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
     if len(text) <= max_chars:
@@ -109,8 +136,8 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
         if len(current) + len(sentence) <= max_chars:
             current += sentence
         else:
-            if current:
-                chunks.append(current.strip())
+            if current: chunks.append(current.strip())
+            # Handle extremely long sentences by word splitting
             if len(sentence) > max_chars:
                 words = sentence.split()
                 temp = ""
@@ -118,22 +145,16 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
                     if len(temp) + len(word) + 1 <= max_chars:
                         temp += word + " "
                     else:
-                        if temp:
-                            chunks.append(temp.strip())
+                        if temp: chunks.append(temp.strip())
                         temp = word + " "
-                if temp:
-                    chunks.append(temp.strip())
+                if temp: chunks.append(temp.strip())
+                current = ""
             else:
                 current = sentence
-    if current:
-        chunks.append(current.strip())
+    if current: chunks.append(current.strip())
     return [c for c in chunks if c]

-# =========================
-# DOCUMENT PROCESSING
-# =========================
 class ChunkProcessor:
     def __init__(self, vectorstore):
         self.vectorstore = vectorstore
@@ -141,9 +162,7 @@ class ChunkProcessor:
     async def process_file(self, file_path: str) -> List[Dict]:
         try:
-            docs = await asyncio.to_thread(
-                UnstructuredMarkdownLoader(file_path).load
-            )
+            docs = await asyncio.to_thread(UnstructuredMarkdownLoader(file_path).load)
         except Exception as e:
             console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]")
             return []
@@ -167,40 +186,15 @@ class ChunkProcessor:
         return chunks

     async def embed_batch(self, batch: List[Dict]) -> bool:
-        if not batch:
-            return True
+        if not batch: return True
         try:
-            docs = [Document(page_content=c["text"], metadata=c["metadata"])
-                    for c in batch]
+            docs = [Document(page_content=c["text"], metadata=c["metadata"]) for c in batch]
             ids = [c["id"] for c in batch]
-            await asyncio.to_thread(
-                self.vectorstore.add_documents,
-                docs,
-                ids=ids
-            )
+            await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids)
             return True
         except Exception as e:
-            error_msg = str(e).lower()
-            if "context length" in error_msg or "input length" in error_msg:
-                console.print(f"[yellow]⚠ Oversized chunk detected, processing individually[/yellow]")
-                for item in batch:
-                    try:
-                        doc = Document(page_content=item["text"], metadata=item["metadata"])
-                        await asyncio.to_thread(
-                            self.vectorstore.add_documents,
-                            [doc],
-                            ids=[item["id"]]
-                        )
-                    except Exception:
-                        console.print(f"[red]✗ Skipping chunk (too large): {len(item['text'])} chars[/red]")
-                        continue
-                return True
-            else:
-                console.print(f"[red]✗ Embed error: {e}[/red]")
-                return False
+            console.print(f"[red]✗ Embed error: {e}[/red]")
+            return False

     async def index_file(self, file_path: str, cache: dict) -> bool:
         async with self.semaphore:
@@ -209,14 +203,16 @@ class ChunkProcessor:
                 return False
             chunks = await self.process_file(file_path)
-            if not chunks:
-                return False
+            if not chunks: return False
+
+            try:
+                self.vectorstore._collection.delete(where={"source": file_path})
+            except:
+                pass # Collection might be empty

             for i in range(0, len(chunks), BATCH_SIZE):
                 batch = chunks[i:i + BATCH_SIZE]
-                success = await self.embed_batch(batch)
-                if not success:
-                    console.print(f"[yellow]⚠ Partial failure in {Path(file_path).name}[/yellow]")
+                await self.embed_batch(batch)

             cache[file_path] = current_hash
             console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
@@ -256,7 +252,7 @@ def start_watcher(processor, cache):
     return observer

 # =========================
-# RAG CHAIN & MEMORY
+# RAG CHAIN FACTORY
 # =========================
 class ConversationMemory:
     def __init__(self, max_messages: int = 8):
@@ -269,18 +265,15 @@ class ConversationMemory:
             self.messages.pop(0)

     def get_history(self) -> str:
-        if not self.messages:
-            return "No previous conversation."
+        if not self.messages: return "No previous conversation."
         return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])

-def get_rag_components(retriever):
-    llm = ChatOllama(model=LLM_MODEL, temperature=0.1)
+def get_chain(system_prompt):
+    llm = ChatOllama(model=LLM_MODEL, temperature=0.2)
     prompt = ChatPromptTemplate.from_messages([
-        ("system", SYSTEM_PROMPT),
+        ("system", system_prompt),
         ("human", USER_PROMPT_TEMPLATE)
     ])
     return prompt | llm | StrOutputParser()

 # =========================
@@ -291,7 +284,7 @@ async def main():
     Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)

     console.print(Panel.fit(
-        f"[bold cyan]⚡ RAG System[/bold cyan]\n"
+        f"[bold cyan]⚡ Dual-Mode RAG System[/bold cyan]\n"
         f"📂 Docs: {MD_DIRECTORY}\n"
         f"🧠 Embed: {EMBEDDING_MODEL}\n"
         f"🤖 LLM: {LLM_MODEL}",
@@ -308,13 +301,14 @@ async def main():
     processor = ChunkProcessor(vectorstore)
     cache = load_hash_cache()

-    console.print("\n[yellow]Indexing documents...[/yellow]")
+    console.print("\n[yellow]Checking documents...[/yellow]")
     files = [
         os.path.join(root, file)
         for root, _, files in os.walk(MD_DIRECTORY)
         for file in files if file.endswith(".md")
     ]

+    # Initial Indexing
     semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
     async def sem_task(fp):
         async with semaphore:
@@ -325,19 +319,10 @@ async def main():
         await fut
     save_hash_cache(cache)
-    console.print(f"[green]✓ Processed {len(files)} files[/green]\n")

     observer = start_watcher(processor, cache)
-    retriever = vectorstore.as_retriever(
-        search_type="similarity",
-        search_kwargs={"k": TOP_K}
-    )
-    rag_chain = get_rag_components(retriever)
     memory = ConversationMemory()

-    console.print("[bold green]💬 Ready![/bold green]\n")
+    console.print("[bold green]💬 Ready! Type 'exit' to quit.[/bold green]\n")

     try:
         with patch_stdout():
@@ -347,15 +332,57 @@ async def main():
                 if query.lower() in {"exit", "quit", "q"}:
                     print("Goodbye!")
                     break
-                if not query:
-                    continue
+                if not query: continue

-                docs = await asyncio.to_thread(retriever.invoke, query)
-                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+                mode = classify_intent(query)
                 history_str = memory.get_history()

+                if mode == "SEARCH":
+                    console.print("[bold blue]🔍 SEARCH MODE (Top-K)[/bold blue]")
+                    # Standard RAG
+                    retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
+                    docs = await asyncio.to_thread(retriever.invoke, query)
+                    context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+                    chain = get_chain(SYSTEM_PROMPT_SEARCH)
+                else: # ANALYSIS MODE
+                    console.print("[bold magenta]📊 ANALYSIS MODE (Full Context)[/bold magenta]")
+                    # Fetch ALL documents (limited by size)
+                    # Chroma .get() returns dict with keys: ids, embeddings, documents, metadatas
+                    db_data = await asyncio.to_thread(vectorstore.get)
+                    all_texts = db_data['documents']
+                    all_metas = db_data['metadatas']
+                    if not all_texts:
+                        console.print("[red]No documents found to analyze![/red]")
+                        continue
+                    # Concatenate content for analysis
+                    full_context = ""
+                    char_count = 0
+                    # Sort arbitrarily or by source to group files
+                    paired = sorted(zip(all_texts, all_metas), key=lambda x: x[1]['source'])
+                    for text, meta in paired:
+                        entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
+                        if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
+                            full_context += "\n[...Truncated due to context limit...]"
+                            console.print("[yellow]⚠ Context limit reached, truncating analysis data.[/yellow]")
+                            break
+                        full_context += entry
+                        char_count += len(entry)
+                    context_str = full_context
+                    chain = get_chain(SYSTEM_PROMPT_ANALYSIS)

                 response = ""
-                async for chunk in rag_chain.astream({
+                console.print(f"[dim]Context size: {len(context_str)} chars[/dim]")
+                async for chunk in chain.astream({
                     "context": context_str,
                     "question": query,
                     "history": history_str