feat: dual-mode RAG system
main.py
@@ -4,14 +4,16 @@ import sys
 import json
 import hashlib
 import asyncio
+import re
 from pathlib import Path
 from collections import deque
-from typing import List, Dict
+from typing import List, Dict, Tuple
 
 import torch
 from dotenv import load_dotenv
 from rich.console import Console
 from rich.panel import Panel
+from rich.markdown import Markdown
 from prompt_toolkit import PromptSession
 from prompt_toolkit.styles import Style
 from prompt_toolkit.patch_stdout import patch_stdout
@@ -36,13 +38,22 @@ load_dotenv()
 
 style = Style.from_dict({"prompt": "bold #6a0dad"})
 
-SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
+# --- PROMPTS ---
+SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
+SYSTEM_PROMPT_ANALYSIS = (
+    "You are an expert tutor and progress evaluator. "
+    "You have access to the student's entire knowledge base below. "
+    "Analyze the coverage, depth, and connections in the notes. "
+    "Identify what the user has learned well, what is missing, and suggest the next logical steps. "
+    "Do not just summarize; evaluate the progress."
+)
+
 USER_PROMPT_TEMPLATE = os.getenv("USER_PROMPT_TEMPLATE",
     "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
 
-MD_DIRECTORY = os.getenv("MD_FOLDER")
-EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
-LLM_MODEL = os.getenv("LLM_MODEL")
+MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
+EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
+LLM_MODEL = os.getenv("LLM_MODEL", "llama3")
 
 CHROMA_PATH = "./.cache/chroma_db"
 HASH_CACHE = "./.cache/file_hashes.json"
@@ -53,6 +64,9 @@ CHUNK_OVERLAP = 200
 TOP_K = 6
 COLLECTION_NAME = "md_rag"
 
+# Limit context size for Analysis mode (approx 24k chars ~ 6k tokens) to prevent OOM
+MAX_ANALYSIS_CONTEXT_CHARS = 24000
+
 BATCH_SIZE = 10
 MAX_PARALLEL_FILES = 3
 
@@ -62,25 +76,18 @@ MAX_PARALLEL_FILES = 3
 def setup_gpu():
     if torch.cuda.is_available():
         torch.cuda.set_per_process_memory_fraction(0.95)
-
         device_id = torch.cuda.current_device()
         device_name = torch.cuda.get_device_name(device_id)
-
-        # VRAM info (in GB)
         total_vram = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
-        allocated = torch.cuda.memory_allocated(device_id) / (1024**3)
-        reserved = torch.cuda.memory_reserved(device_id) / (1024**3)
-        free = total_vram - reserved
-
-        console.print(f"[green]✓ GPU: {device_name}[/green]")
-        console.print(f"[blue] VRAM: {total_vram:.1f}GB total | {free:.1f}GB free | {allocated:.1f}GB allocated[/blue]")
+        console.print(f"[green]✓ GPU: {device_name} ({total_vram:.1f}GB)[/green]\n")
     else:
-        console.print("[yellow]⚠ CPU mode[/yellow]")
+        console.print("[yellow]⚠ CPU mode[/yellow]\n")
+        console.print("\n")
 
 setup_gpu()
 
 # =========================
-# HASH CACHE
+# UTILS & CACHE
 # =========================
 def get_file_hash(file_path: str) -> str:
     return hashlib.md5(Path(file_path).read_bytes()).hexdigest()
@@ -95,7 +102,27 @@ def save_hash_cache(cache: dict):
     Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))
 
 # =========================
-# CHUNK VALIDATION
+# ROUTING LOGIC
+# =========================
+def classify_intent(query: str) -> str:
+    """
+    Determines if the user wants a specific search (RAG) or a global assessment.
+    """
+    analysis_keywords = [
+        r"assess my progress", r"eval(uate)? my (learning|knowledge)",
+        r"what have i learned", r"summary of (my )?notes",
+        r"my progress", r"learning path", r"knowledge gap",
+        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний"
+    ]
+
+    query_lower = query.lower()
+    for pattern in analysis_keywords:
+        if re.search(pattern, query_lower):
+            return "ANALYSIS"
+    return "SEARCH"
+
+# =========================
+# DOCUMENT PROCESSING
 # =========================
 def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
     if len(text) <= max_chars:
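Below is a minimal, standalone sketch (illustrative, not part of the commit) of how the new keyword router is expected to behave; the pattern list is a condensed copy of the English patterns above, and the sample queries are hypothetical:

# Condensed, standalone copy of the routing check above, for illustration only.
import re

ANALYSIS_PATTERNS = [
    r"assess my progress", r"what have i learned",
    r"my progress", r"knowledge gap",
]

def route(query: str) -> str:
    # Lowercase the query, then fall back to SEARCH unless an analysis trigger matches.
    q = query.lower()
    return "ANALYSIS" if any(re.search(p, q) for p in ANALYSIS_PATTERNS) else "SEARCH"

assert route("Assess my progress in Rust") == "ANALYSIS"
assert route("How do I configure the Chroma retriever?") == "SEARCH"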
@@ -109,8 +136,8 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str
         if len(current) + len(sentence) <= max_chars:
             current += sentence
         else:
-            if current:
-                chunks.append(current.strip())
+            if current: chunks.append(current.strip())
+            # Handle extremely long sentences by word splitting
             if len(sentence) > max_chars:
                 words = sentence.split()
                 temp = ""
@@ -118,22 +145,16 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str
                     if len(temp) + len(word) + 1 <= max_chars:
                         temp += word + " "
                     else:
-                        if temp:
-                            chunks.append(temp.strip())
+                        if temp: chunks.append(temp.strip())
                         temp = word + " "
-                if temp:
-                    chunks.append(temp.strip())
+                if temp: chunks.append(temp.strip())
+                current = ""
             else:
                 current = sentence
 
-    if current:
-        chunks.append(current.strip())
-
+    if current: chunks.append(current.strip())
     return [c for c in chunks if c]
 
-# =========================
-# DOCUMENT PROCESSING
-# =========================
 class ChunkProcessor:
     def __init__(self, vectorstore):
         self.vectorstore = vectorstore
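Below is a minimal, standalone sketch (illustrative, not part of the commit) of the word-splitting fallback that the loop above applies to oversized sentences; the helper name split_long_sentence is hypothetical and only mirrors the inner loop:

# Standalone mirror of the word-splitting fallback, for illustration only.
def split_long_sentence(sentence: str, max_chars: int) -> list[str]:
    chunks, temp = [], ""
    for word in sentence.split():
        if len(temp) + len(word) + 1 <= max_chars:
            temp += word + " "
        else:
            # Flush the accumulator before starting a new chunk with the current word.
            if temp: chunks.append(temp.strip())
            temp = word + " "
    if temp: chunks.append(temp.strip())
    return chunks

parts = split_long_sentence("lorem " * 500, max_chars=100)
assert all(len(p) <= 100 for p in parts)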
@@ -141,9 +162,7 @@ class ChunkProcessor:
 
     async def process_file(self, file_path: str) -> List[Dict]:
         try:
-            docs = await asyncio.to_thread(
-                UnstructuredMarkdownLoader(file_path).load
-            )
+            docs = await asyncio.to_thread(UnstructuredMarkdownLoader(file_path).load)
         except Exception as e:
             console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]")
             return []
@@ -167,38 +186,13 @@ class ChunkProcessor:
         return chunks
 
     async def embed_batch(self, batch: List[Dict]) -> bool:
-        if not batch:
-            return True
-
+        if not batch: return True
         try:
-            docs = [Document(page_content=c["text"], metadata=c["metadata"])
-                    for c in batch]
+            docs = [Document(page_content=c["text"], metadata=c["metadata"]) for c in batch]
             ids = [c["id"] for c in batch]
-            await asyncio.to_thread(
-                self.vectorstore.add_documents,
-                docs,
-                ids=ids
-            )
+            await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids)
             return True
 
         except Exception as e:
-            error_msg = str(e).lower()
-            if "context length" in error_msg or "input length" in error_msg:
-                console.print(f"[yellow]⚠ Oversized chunk detected, processing individually[/yellow]")
-                for item in batch:
-                    try:
-                        doc = Document(page_content=item["text"], metadata=item["metadata"])
-                        await asyncio.to_thread(
-                            self.vectorstore.add_documents,
-                            [doc],
-                            ids=[item["id"]]
-                        )
-                    except Exception:
-                        console.print(f"[red]✗ Skipping chunk (too large): {len(item['text'])} chars[/red]")
-                        continue
-                return True
-            else:
-                console.print(f"[red]✗ Embed error: {e}[/red]")
+            console.print(f"[red]✗ Embed error: {e}[/red]")
             return False
-
@@ -209,14 +203,16 @@ class ChunkProcessor:
             return False
 
         chunks = await self.process_file(file_path)
-        if not chunks:
-            return False
+        if not chunks: return False
+
+        try:
+            self.vectorstore._collection.delete(where={"source": file_path})
+        except:
+            pass # Collection might be empty
 
         for i in range(0, len(chunks), BATCH_SIZE):
             batch = chunks[i:i + BATCH_SIZE]
-            success = await self.embed_batch(batch)
-            if not success:
-                console.print(f"[yellow]⚠ Partial failure in {Path(file_path).name}[/yellow]")
+            await self.embed_batch(batch)
 
         cache[file_path] = current_hash
         console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
@@ -256,7 +252,7 @@ def start_watcher(processor, cache):
     return observer
 
 # =========================
-# RAG CHAIN & MEMORY
+# RAG CHAIN FACTORY
 # =========================
 class ConversationMemory:
     def __init__(self, max_messages: int = 8):
@@ -269,18 +265,15 @@ class ConversationMemory:
             self.messages.pop(0)
 
     def get_history(self) -> str:
-        if not self.messages:
-            return "No previous conversation."
+        if not self.messages: return "No previous conversation."
        return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])
 
-def get_rag_components(retriever):
-    llm = ChatOllama(model=LLM_MODEL, temperature=0.1)
-
+def get_chain(system_prompt):
+    llm = ChatOllama(model=LLM_MODEL, temperature=0.2)
     prompt = ChatPromptTemplate.from_messages([
-        ("system", SYSTEM_PROMPT),
+        ("system", system_prompt),
         ("human", USER_PROMPT_TEMPLATE)
     ])
-
     return prompt | llm | StrOutputParser()
 
 # =========================
@@ -291,7 +284,7 @@ async def main():
     Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)
 
     console.print(Panel.fit(
-        f"[bold cyan]⚡ RAG System[/bold cyan]\n"
+        f"[bold cyan]⚡ Dual-Mode RAG System[/bold cyan]\n"
         f"📂 Docs: {MD_DIRECTORY}\n"
         f"🧠 Embed: {EMBEDDING_MODEL}\n"
        f"🤖 LLM: {LLM_MODEL}",
@@ -308,13 +301,14 @@ async def main():
     processor = ChunkProcessor(vectorstore)
     cache = load_hash_cache()
 
-    console.print("\n[yellow]Indexing documents...[/yellow]")
+    console.print("\n[yellow]Checking documents...[/yellow]")
     files = [
         os.path.join(root, file)
         for root, _, files in os.walk(MD_DIRECTORY)
         for file in files if file.endswith(".md")
     ]
 
+    # Initial Indexing
     semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
     async def sem_task(fp):
         async with semaphore:
@@ -325,19 +319,10 @@ async def main():
         await fut
     save_hash_cache(cache)
 
-    console.print(f"[green]✓ Processed {len(files)} files[/green]\n")
-
     observer = start_watcher(processor, cache)
-
-    retriever = vectorstore.as_retriever(
-        search_type="similarity",
-        search_kwargs={"k": TOP_K}
-    )
-
-    rag_chain = get_rag_components(retriever)
     memory = ConversationMemory()
 
-    console.print("[bold green]💬 Ready![/bold green]\n")
+    console.print("[bold green]💬 Ready! Type 'exit' to quit.[/bold green]\n")
 
     try:
         with patch_stdout():
@@ -347,15 +332,57 @@ async def main():
                 if query.lower() in {"exit", "quit", "q"}:
                     print("Goodbye!")
                     break
-                if not query:
-                    continue
+                if not query: continue
 
-                docs = await asyncio.to_thread(retriever.invoke, query)
-                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+                mode = classify_intent(query)
                 history_str = memory.get_history()
 
+                if mode == "SEARCH":
+                    console.print("[bold blue]🔍 SEARCH MODE (Top-K)[/bold blue]")
+
+                    # Standard RAG
+                    retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
+                    docs = await asyncio.to_thread(retriever.invoke, query)
+                    context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+
+                    chain = get_chain(SYSTEM_PROMPT_SEARCH)
+
+                else: # ANALYSIS MODE
+                    console.print("[bold magenta]📊 ANALYSIS MODE (Full Context)[/bold magenta]")
+
+                    # Fetch ALL documents (limited by size)
+                    # Chroma .get() returns dict with keys: ids, embeddings, documents, metadatas
+                    db_data = await asyncio.to_thread(vectorstore.get)
+                    all_texts = db_data['documents']
+                    all_metas = db_data['metadatas']
+
+                    if not all_texts:
+                        console.print("[red]No documents found to analyze![/red]")
+                        continue
+
+                    # Concatenate content for analysis
+                    full_context = ""
+                    char_count = 0
+
+                    # Sort arbitrarily or by source to group files
+                    paired = sorted(zip(all_texts, all_metas), key=lambda x: x[1]['source'])
+
+                    for text, meta in paired:
+                        entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
+                        if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
+                            full_context += "\n[...Truncated due to context limit...]"
+                            console.print("[yellow]⚠ Context limit reached, truncating analysis data.[/yellow]")
+                            break
+                        full_context += entry
+                        char_count += len(entry)
+
+                    context_str = full_context
+                    chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
+
                 response = ""
-                async for chunk in rag_chain.astream({
+                console.print(f"[dim]Context size: {len(context_str)} chars[/dim]")
+
+                async for chunk in chain.astream({
                     "context": context_str,
                     "question": query,
                     "history": history_str
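Below is a minimal, standalone sketch (illustrative, not part of the commit) of the budgeted context assembly that ANALYSIS mode performs; the function name build_analysis_context and the sample inputs are hypothetical, while the entry format and the 24000-character budget mirror the diff:

# Standalone mirror of the ANALYSIS-mode context builder, for illustration only.
from pathlib import Path

def build_analysis_context(texts, metas, max_chars=24000):
    # Group chunks by source file, then append entries until the budget is hit.
    pairs = sorted(zip(texts, metas), key=lambda x: x[1]["source"])
    out, used = "", 0
    for text, meta in pairs:
        entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
        if used + len(entry) > max_chars:
            out += "\n[...Truncated due to context limit...]"
            break
        out += entry
        used += len(entry)
    return out

ctx = build_analysis_context(
    ["Rust ownership notes", "Borrow checker notes"],
    [{"source": "notes/rust.md"}, {"source": "notes/borrow.md"}],
)
assert "rust.md" in ctx and "borrow.md" in ctx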