#!/usr/bin/env python3
import os
import sys
import json
import hashlib
import asyncio
import re
from pathlib import Path
from collections import deque
from typing import List, Dict

from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from prompt_toolkit import PromptSession
from prompt_toolkit.styles import Style

from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

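# NOTE: third-party dependencies implied by the imports above. The PyPI package
# names are assumptions based on common naming, not pinned by this file:
#   python-dotenv, rich, prompt-toolkit, watchdog, nest-asyncio,
#   langchain-community, langchain-text-splitters, langchain-ollama,
#   langchain-chroma, langchain-core, unstructured (for UnstructuredMarkdownLoader)
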
# =========================
# CONFIG
# =========================
console = Console(color_system="standard", force_terminal=True)
session = PromptSession()
load_dotenv()

style = Style.from_dict({"prompt": "bold #6a0dad"})

OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
ANSWER_COLOR = os.getenv("ANSWER_COLOR", "blue")

SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
SYSTEM_PROMPT_ANALYSIS = (
    "You are an expert tutor and progress evaluator. "
    "You have access to the student's entire knowledge base below. "
    "Analyze the coverage, depth, and connections in the notes. "
    "Identify what the user has learned well, what is missing, and suggest the next logical steps. "
    "Do not just summarize; evaluate the progress."
)

USER_PROMPT_TEMPLATE = os.getenv(
    "USER_PROMPT_TEMPLATE",
    "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}"
)

MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
LLM_MODEL = os.getenv("LLM_MODEL", "llama3")

CHROMA_PATH = "./.cache/chroma_db"
HASH_CACHE = "./.cache/file_hashes.json"

MAX_EMBED_CHARS = 380
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 200
TOP_K = 6
COLLECTION_NAME = "md_rag"

MAX_ANALYSIS_CONTEXT_CHARS = 24000

BATCH_SIZE = 10
MAX_PARALLEL_FILES = 3

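# Illustrative .env layout (example values only; just the variable names are
# read by the os.getenv calls above, and every one of them is optional):
#
#   OLLAMA_BASE_URL=http://localhost:11434
#   MD_FOLDER=./notes
#   EMBEDDING_MODEL=nomic-embed-text
#   LLM_MODEL=llama3
#   ANSWER_COLOR=blue
#   SYSTEM_PROMPT=You are a precise technical assistant. Cite sources using [filename].
#   USER_PROMPT_TEMPLATE=...
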
# =========================
# UTILS & CACHE
# =========================
def get_file_hash(file_path: str) -> str:
    return hashlib.md5(Path(file_path).read_bytes()).hexdigest()

def load_hash_cache() -> dict:
    Path(HASH_CACHE).parent.mkdir(parents=True, exist_ok=True)
    if Path(HASH_CACHE).exists():
        return json.loads(Path(HASH_CACHE).read_text())
    return {}

def save_hash_cache(cache: dict):
    Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))

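# For reference, HASH_CACHE is plain JSON mapping each indexed file path to the
# MD5 hex digest of its bytes, e.g. (hypothetical entry):
#   {"./notes/python_basics.md": "9e107d9d372bb6826bd81d3542a419d6"}
# A file is re-embedded only when its current digest differs from the cached one.
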
# =========================
# ROUTING LOGIC
# =========================
def classify_intent(query: str) -> str:
    # English and Russian trigger phrases. The Russian patterns roughly mean:
    # "assess (my) progress", "what have I learned", "takeaways",
    # "knowledge analysis", "today('s) ... lesson", "what I have studied".
    analysis_keywords = [
        r"assess my progress", r"eval(uate)? my (learning|knowledge)",
        r"what have i learned", r"summary of (my )?notes",
        r"my progress", r"learning path", r"knowledge gap",
        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний",
        r"сегодня(?:\s+\w+)*\s*урок", r"что я изучил"
    ]

    query_lower = query.lower()
    for pattern in analysis_keywords:
        if re.search(pattern, query_lower):
            return "ANALYSIS"
    return "SEARCH"

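# Illustrative routing (hypothetical queries):
#   classify_intent("What have I learned about asyncio?")   -> "ANALYSIS"
#   classify_intent("How does ChunkProcessor batch embeds?") -> "SEARCH"
# Any query matching none of the patterns falls through to SEARCH mode.
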
# =========================
# DOCUMENT PROCESSING
# =========================
def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
    """Split a chunk that exceeds max_chars into sentence-sized (then word-sized) pieces."""
    if len(text) <= max_chars:
        return [text]

    # Mark sentence boundaries, keeping the trailing space so sentences are not
    # fused together when they are concatenated back below.
    sentences = text.replace('. ', '. |').replace('! ', '! |').replace('? ', '? |').split('|')
    chunks = []
    current = ""

    for sentence in sentences:
        if len(current) + len(sentence) <= max_chars:
            current += sentence
        else:
            if current: chunks.append(current.strip())
            if len(sentence) > max_chars:
                # A single oversized sentence: fall back to word-level packing.
                words = sentence.split()
                temp = ""
                for word in words:
                    if len(temp) + len(word) + 1 <= max_chars:
                        temp += word + " "
                    else:
                        if temp: chunks.append(temp.strip())
                        temp = word + " "
                if temp: chunks.append(temp.strip())
                current = ""
            else:
                current = sentence

    if current: chunks.append(current.strip())
    return [c for c in chunks if c]

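# Note: RecursiveCharacterTextSplitter (used in ChunkProcessor.process_file
# below) targets CHUNK_SIZE (1200 chars), while validate_chunk_size further
# caps each piece at MAX_EMBED_CHARS (380 chars) before embedding, so a full
# 1200-char splitter chunk becomes at least four embedding-sized sub-chunks.
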
class ChunkProcessor:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
        self.semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def process_file(self, file_path: str) -> List[Dict]:
        try:
            docs = await asyncio.to_thread(UnstructuredMarkdownLoader(file_path).load)
        except Exception as e:
            console.print(f"✗ {Path(file_path).name}: {e}", style="red")
            return []

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", " "]
        )

        chunks = []
        for doc_idx, doc in enumerate(docs):
            for chunk_idx, text in enumerate(splitter.split_text(doc.page_content)):
                safe_texts = validate_chunk_size(text)
                for sub_idx, safe_text in enumerate(safe_texts):
                    chunks.append({
                        "id": f"{file_path}::{doc_idx}::{chunk_idx}::{sub_idx}",
                        "text": safe_text,
                        "metadata": {"source": file_path, **doc.metadata}
                    })
        return chunks

    async def embed_batch(self, batch: List[Dict]) -> bool:
        if not batch: return True
        try:
            docs = [Document(page_content=c["text"], metadata=c["metadata"]) for c in batch]
            ids = [c["id"] for c in batch]
            await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids)
            return True
        except Exception as e:
            console.print(f"✗ Embed error: {e}", style="red")
            return False

    async def index_file(self, file_path: str, cache: dict) -> bool:
        async with self.semaphore:
            current_hash = get_file_hash(file_path)
            if cache.get(file_path) == current_hash:
                return False

            chunks = await self.process_file(file_path)
            if not chunks: return False

            # Drop any previously stored chunks for this file before re-embedding.
            try:
                self.vectorstore._collection.delete(where={"source": file_path})
            except Exception:
                pass

            for i in range(0, len(chunks), BATCH_SIZE):
                batch = chunks[i:i + BATCH_SIZE]
                await self.embed_batch(batch)

            cache[file_path] = current_hash
            console.print(f"✓ {Path(file_path).name} ({len(chunks)} chunks)", style="green")
            return True

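# Indexing pipeline, summarised: index_file skips unchanged files via the MD5
# cache, deletes the file's old vectors, then embeds its chunks in batches of
# BATCH_SIZE (10) through Ollama, with at most MAX_PARALLEL_FILES (3) files
# in flight at once thanks to the semaphore.
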
# =========================
# FILE WATCHER
# =========================
class DocumentWatcher(FileSystemEventHandler):
    def __init__(self, processor, cache):
        self.processor = processor
        self.cache = cache
        self.queue = deque()
        self.processing = False

    def on_modified(self, event):
        # Called from watchdog's observer thread; only enqueue here and let the
        # asyncio task below do the actual re-indexing.
        if not event.is_directory and event.src_path.endswith(".md"):
            self.queue.append(event.src_path)

    async def process_queue(self):
        while True:
            if self.queue and not self.processing:
                self.processing = True
                file_path = self.queue.popleft()
                if Path(file_path).exists():
                    await self.processor.index_file(file_path, self.cache)
                    save_hash_cache(self.cache)
                self.processing = False
            await asyncio.sleep(1)

def start_watcher(processor, cache):
    handler = DocumentWatcher(processor, cache)
    observer = Observer()
    observer.schedule(handler, MD_DIRECTORY, recursive=True)
    observer.start()
    asyncio.create_task(handler.process_queue())
    return observer

# =========================
# RAG CHAIN FACTORY
# =========================
class ConversationMemory:
    def __init__(self, max_messages: int = 8):
        self.messages = []
        self.max_messages = max_messages

    def add(self, role: str, content: str):
        self.messages.append({"role": role, "content": content})
        if len(self.messages) > self.max_messages:
            self.messages.pop(0)

    def get_history(self) -> str:
        if not self.messages: return "No previous conversation."
        return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])

def get_chain(system_prompt):
    llm = ChatOllama(
        model=LLM_MODEL,
        temperature=0.2,
        base_url=OLLAMA_BASE_URL
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", USER_PROMPT_TEMPLATE)
    ])
    return prompt | llm | StrOutputParser()

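# Minimal usage sketch (mirrors how main() drives the chain; the dict keys must
# match the placeholders in USER_PROMPT_TEMPLATE):
#
#   chain = get_chain(SYSTEM_PROMPT_SEARCH)
#   async for token in chain.astream({"context": "...", "question": "...", "history": "..."}):
#       console.print(token, end="")
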
# =========================
# MAIN
# =========================
async def main():
    Path(MD_DIRECTORY).mkdir(parents=True, exist_ok=True)
    Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)

    console.print(Panel.fit(
        f"[bold cyan]⚡ Dual-Mode RAG System[/bold cyan]\n"
        f"📂 Docs: {MD_DIRECTORY}\n"
        f"🧠 Embed: {EMBEDDING_MODEL}\n"
        f"🤖 LLM: {LLM_MODEL}",
        border_style="cyan"
    ))

    embeddings = OllamaEmbeddings(
        model=EMBEDDING_MODEL,
        base_url=OLLAMA_BASE_URL
    )
    vectorstore = Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory=CHROMA_PATH,
        embedding_function=embeddings
    )

    processor = ChunkProcessor(vectorstore)
    cache = load_hash_cache()

    # Collect every markdown file under the notes directory.
    files = [
        os.path.join(root, file)
        for root, _, files in os.walk(MD_DIRECTORY)
        for file in files if file.endswith(".md")
    ]

    # Index new or changed files concurrently, capped by the semaphore.
    semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def sem_task(fp):
        async with semaphore:
            return await processor.index_file(fp, cache)

    tasks = [sem_task(fp) for fp in files]
    for fut in asyncio.as_completed(tasks):
        await fut
    save_hash_cache(cache)

    observer = start_watcher(processor, cache)
    memory = ConversationMemory()

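    # Interactive loop: classify_intent routes each query either to SEARCH
    # (top-k retrieval from Chroma) or to ANALYSIS (all stored, non-excluded
    # chunks packed into one context), and the answer is streamed token by token.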
    try:
        while True:
            query = await session.prompt_async("> ", style=style)
            query = query.strip()
            if query.lower() in {"exit", "quit", "q"}:
                console.print("\nGoodbye!", style="yellow")
                break
            if not query: continue

            console.print()

            mode = classify_intent(query)
            history_str = memory.get_history()

            if mode == "SEARCH":
                console.print("🔍 SEARCH MODE (Top-K)", style="bold blue")

                retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
                docs = await asyncio.to_thread(retriever.invoke, query)
                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)

                chain = get_chain(SYSTEM_PROMPT_SEARCH)

            else:  # ANALYSIS MODE
                console.print("📊 ANALYSIS MODE (Full Context)", style="bold magenta")

                # Pull every stored chunk (text + metadata) out of the vector store.
                db_data = await asyncio.to_thread(vectorstore.get)
                all_texts = db_data['documents']
                all_metas = db_data['metadatas']

                if not all_texts:
                    console.print("No documents found to analyze!", style="red")
                    continue

                # Exclude chunks whose metadata has exclude: true
                filtered_pairs = [
                    (text, meta) for text, meta in zip(all_texts, all_metas)
                    if meta and not meta.get('exclude', False)
                ]

                excluded_count = len(all_texts) - len(filtered_pairs)
                if excluded_count > 0:
                    console.print(f"ℹ Excluded {excluded_count} chunks marked 'exclude: true'", style="dim")

                if not filtered_pairs:
                    console.print("All documents are marked for exclusion. Nothing to analyze.", style="yellow")
                    continue

                full_context = ""
                char_count = 0

                # Group chunks by source file so the model sees notes in a stable order.
                paired = sorted(filtered_pairs, key=lambda x: x[1]['source'])

                for text, meta in paired:
                    entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
                    if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
                        full_context += "\n[...Truncated due to context limit...]"
                        console.print("⚠ Context limit reached, truncating analysis data.", style="yellow")
                        break
                    full_context += entry
                    char_count += len(entry)

                context_str = full_context
                chain = get_chain(SYSTEM_PROMPT_ANALYSIS)

            response = ""
            console.print(f"Context size: {len(context_str)} chars", style="dim")
            console.print("Assistant:", style="blue", end=" ")

            async for chunk in chain.astream({
                "context": context_str,
                "question": query,
                "history": history_str
            }):
                console.print(chunk, end="", style=ANSWER_COLOR)
                response += chunk
            console.print("\n")

            memory.add("user", query)
            memory.add("assistant", response)

    finally:
        observer.stop()
        observer.join()

if __name__ == "__main__":
    # nest_asyncio patches asyncio so the event loop can be re-entered
    # (nested run_until_complete calls), which interactive prompts may need.
    import nest_asyncio
    nest_asyncio.apply()
    try:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        console.print("\nGoodbye!", style="yellow")
        sys.exit(0)