Initial commit
.env.example (new file, 3 lines)
@@ -0,0 +1,3 @@
MD_FOLDER=my_docs
EMBEDDING_MODEL=mxbai-embed-large:latest
LLM_MODEL=qwen2.5:7b-instruct-q8_0
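`main.py` reads these values with `os.getenv` after `load_dotenv()` and does not check that they are actually set. The sketch below is a hypothetical guard, not part of this commit, that fails fast when `.env` is incomplete; the variable names mirror `.env.example`.

```python
# Hypothetical startup guard (not in this commit): fail fast if .env is incomplete.
import os
from dotenv import load_dotenv

load_dotenv()

REQUIRED = ("MD_FOLDER", "EMBEDDING_MODEL", "LLM_MODEL")
missing = [name for name in REQUIRED if not os.getenv(name)]
if missing:
    raise SystemExit(f"Missing required .env settings: {', '.join(missing)}")
```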
.gitignore (new file, 15 lines, vendored)
@@ -0,0 +1,15 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

#
.env
.cache/
my_docs/
.python-version (new file, 1 line)
@@ -0,0 +1 @@
3.13
README.md (new file, 22 lines)
@@ -0,0 +1,22 @@
# Local RAG System for Markdown Files

## Requirements

1. Install [Ollama](https://ollama.com/download/)
2. Pull the required models:

```bash
ollama pull qwen2.5:7b-instruct-q8_0
ollama pull mxbai-embed-large:latest
```

You can use any models you like; just update the model names in `.env` to match.

## Run

```bash
cp .env.example .env
uv sync
uv run main.py
```
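Before launching, it can help to confirm the pulled models actually respond. The sketch below is not part of this commit; it uses the same `langchain-ollama` classes that `main.py` imports, with model names copied from `.env.example` (adjust them if you changed `.env`).

```python
# Optional sanity check (sketch): verify both Ollama models answer before running main.py.
from langchain_ollama import OllamaEmbeddings, ChatOllama

embeddings = OllamaEmbeddings(model="mxbai-embed-large:latest")
print("embedding dims:", len(embeddings.embed_query("hello")))   # e.g. 1024 floats

llm = ChatOllama(model="qwen2.5:7b-instruct-q8_0", temperature=0.0)
print(llm.invoke("Reply with OK").content)
```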
main.py (new file, 368 lines)
@@ -0,0 +1,368 @@
#!/usr/bin/env python3
import os
import sys
import json
import hashlib
import asyncio
from pathlib import Path
from collections import deque
from typing import List, Dict

import torch
from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from prompt_toolkit import PromptSession
from prompt_toolkit.styles import Style
from prompt_toolkit.patch_stdout import patch_stdout

from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

# =========================
# CONFIG
# =========================
console = Console()
session = PromptSession()
load_dotenv()

style = Style.from_dict({"prompt": "bold #6a0dad"})

MD_DIRECTORY = os.getenv("MD_FOLDER")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
LLM_MODEL = os.getenv("LLM_MODEL")

CHROMA_PATH = "./.cache/chroma_db"
HASH_CACHE = "./.cache/file_hashes.json"

MAX_EMBED_CHARS = 380    # hard cap (chars) per chunk sent to the embedding model
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 200
TOP_K = 6
COLLECTION_NAME = "md_rag"

BATCH_SIZE = 10          # chunks per add_documents call
MAX_PARALLEL_FILES = 3   # files indexed concurrently

# =========================
# GPU SETUP
# =========================
def setup_gpu():
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.95)
    else:
        console.print("[yellow]⚠ CPU mode[/yellow]")

setup_gpu()

# =========================
# HASH CACHE
# =========================
def get_file_hash(file_path: str) -> str:
    return hashlib.md5(Path(file_path).read_bytes()).hexdigest()

def load_hash_cache() -> dict:
    Path(HASH_CACHE).parent.mkdir(parents=True, exist_ok=True)
    if Path(HASH_CACHE).exists():
        return json.loads(Path(HASH_CACHE).read_text())
    return {}

def save_hash_cache(cache: dict):
    Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))

# =========================
# CHUNK VALIDATION
# =========================
def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
    # Split any text longer than max_chars on sentence boundaries,
    # falling back to word boundaries for single oversized sentences.
    if len(text) <= max_chars:
        return [text]

    sentences = text.replace('. ', '.|').replace('! ', '!|').replace('? ', '?|').split('|')
    chunks = []
    current = ""

    for sentence in sentences:
        if len(current) + len(sentence) <= max_chars:
            current += sentence
        else:
            if current:
                chunks.append(current.strip())
            if len(sentence) > max_chars:
                words = sentence.split()
                temp = ""
                for word in words:
                    if len(temp) + len(word) + 1 <= max_chars:
                        temp += word + " "
                    else:
                        if temp:
                            chunks.append(temp.strip())
                        temp = word + " "
                if temp:
                    chunks.append(temp.strip())
            else:
                current = sentence

    if current:
        chunks.append(current.strip())

    return [c for c in chunks if c]

# =========================
# DOCUMENT PROCESSING
# =========================
class ChunkProcessor:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
        self.semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def process_file(self, file_path: str) -> List[Dict]:
        try:
            docs = await asyncio.to_thread(
                UnstructuredMarkdownLoader(file_path).load
            )
        except Exception as e:
            console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]")
            return []

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", " "]
        )

        chunks = []
        for doc_idx, doc in enumerate(docs):
            for chunk_idx, text in enumerate(splitter.split_text(doc.page_content)):
                safe_texts = validate_chunk_size(text)
                for sub_idx, safe_text in enumerate(safe_texts):
                    chunks.append({
                        "id": f"{file_path}::{doc_idx}::{chunk_idx}::{sub_idx}",
                        "text": safe_text,
                        "metadata": {"source": file_path, **doc.metadata}
                    })
        return chunks

    async def embed_batch(self, batch: List[Dict]) -> bool:
        if not batch:
            return True

        try:
            docs = [Document(page_content=c["text"], metadata=c["metadata"])
                    for c in batch]
            ids = [c["id"] for c in batch]

            await asyncio.to_thread(
                self.vectorstore.add_documents,
                docs,
                ids=ids
            )
            return True

        except Exception as e:
            error_msg = str(e).lower()
            if "context length" in error_msg or "input length" in error_msg:
                console.print("[yellow]⚠ Oversized chunk detected, processing individually[/yellow]")
                for item in batch:
                    try:
                        doc = Document(page_content=item["text"], metadata=item["metadata"])
                        await asyncio.to_thread(
                            self.vectorstore.add_documents,
                            [doc],
                            ids=[item["id"]]
                        )
                    except Exception:
                        console.print(f"[red]✗ Skipping chunk (too large): {len(item['text'])} chars[/red]")
                        continue
                return True
            else:
                console.print(f"[red]✗ Embed error: {e}[/red]")
                return False

    async def index_file(self, file_path: str, cache: dict) -> bool:
        async with self.semaphore:
            current_hash = get_file_hash(file_path)
            if cache.get(file_path) == current_hash:
                return False

            chunks = await self.process_file(file_path)
            if not chunks:
                return False

            for i in range(0, len(chunks), BATCH_SIZE):
                batch = chunks[i:i + BATCH_SIZE]
                success = await self.embed_batch(batch)
                if not success:
                    console.print(f"[yellow]⚠ Partial failure in {Path(file_path).name}[/yellow]")

            cache[file_path] = current_hash
            console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
            return True

# =========================
# FILE WATCHER
# =========================
class DocumentWatcher(FileSystemEventHandler):
    def __init__(self, processor, cache):
        self.processor = processor
        self.cache = cache
        self.queue = deque()
        self.processing = False

    def on_modified(self, event):
        if not event.is_directory and event.src_path.endswith(".md"):
            self.queue.append(event.src_path)

    async def process_queue(self):
        while True:
            if self.queue and not self.processing:
                self.processing = True
                file_path = self.queue.popleft()
                if Path(file_path).exists():
                    await self.processor.index_file(file_path, self.cache)
                    save_hash_cache(self.cache)
                self.processing = False
            await asyncio.sleep(1)

def start_watcher(processor, cache):
    handler = DocumentWatcher(processor, cache)
    observer = Observer()
    observer.schedule(handler, MD_DIRECTORY, recursive=True)
    observer.start()
    asyncio.create_task(handler.process_queue())
    return observer

# =========================
# RAG CHAIN & MEMORY
# =========================
class ConversationMemory:
    def __init__(self, max_messages: int = 8):
        self.messages = []
        self.max_messages = max_messages

    def add(self, role: str, content: str):
        self.messages.append({"role": role, "content": content})
        if len(self.messages) > self.max_messages:
            self.messages.pop(0)

    def get_history(self) -> str:
        if not self.messages:
            return "No previous conversation."
        return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])

def get_rag_components(retriever):
    # Note: retrieval happens in main(); the retriever argument is currently unused here.
    llm = ChatOllama(model=LLM_MODEL, temperature=0.1)

    # FIX 1: Added {history} to the prompt
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a precise technical assistant. Cite sources using [filename]. Be concise."),
        ("human", "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
    ])

    return prompt | llm | StrOutputParser()

# =========================
# MAIN
# =========================
async def main():
    Path(MD_DIRECTORY).mkdir(parents=True, exist_ok=True)
    Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)

    console.print(Panel.fit(
        f"[bold cyan]⚡ RAG System[/bold cyan]\n"
        f"📂 Docs: {MD_DIRECTORY}\n"
        f"🧠 Embed: {EMBEDDING_MODEL}\n"
        f"🤖 LLM: {LLM_MODEL}",
        border_style="cyan"
    ))

    embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
    vectorstore = Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory=CHROMA_PATH,
        embedding_function=embeddings
    )

    processor = ChunkProcessor(vectorstore)
    cache = load_hash_cache()

    console.print("\n[yellow]Indexing documents...[/yellow]")
    files = [
        os.path.join(root, file)
        for root, _, files in os.walk(MD_DIRECTORY)
        for file in files if file.endswith(".md")
    ]

    semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
    async def sem_task(fp):
        async with semaphore:
            return await processor.index_file(fp, cache)

    tasks = [sem_task(fp) for fp in files]
    for fut in asyncio.as_completed(tasks):
        await fut
    save_hash_cache(cache)

    console.print(f"[green]✓ Processed {len(files)} files[/green]\n")

    observer = start_watcher(processor, cache)

    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": TOP_K}
    )

    rag_chain = get_rag_components(retriever)
    memory = ConversationMemory()

    console.print("[bold green]💬 Ready![/bold green]\n")

    try:
        with patch_stdout():
            while True:
                query = await session.prompt_async("> ", style=style)
                query = query.strip()
                if query.lower() in {"exit", "quit", "q"}:
                    print("Goodbye!")
                    break
                if not query:
                    continue

                docs = await asyncio.to_thread(retriever.invoke, query)
                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
                history_str = memory.get_history()

                response = ""
                async for chunk in rag_chain.astream({
                    "context": context_str,
                    "question": query,
                    "history": history_str
                }):
                    print(chunk, end="", flush=True)  # flush so streamed tokens appear immediately
                    response += chunk
                console.print("\n")

                memory.add("user", query)
                memory.add("assistant", response)

    finally:
        observer.stop()
        observer.join()

if __name__ == "__main__":
    import nest_asyncio
    nest_asyncio.apply()
    try:
        import asyncio
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        console.print("\n[yellow]Goodbye![/yellow]")
        sys.exit(0)
pyproject.toml (new file, 19 lines)
@@ -0,0 +1,19 @@
[project]
name = "rag-llm"
version = "0.1.0"
description = "Local RAG system for Markdown files using Ollama, LangChain, and Chroma"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "chromadb>=1.4.0",
    "langchain-chroma>=1.1.0",
    "langchain-community>=0.4.1",
    "langchain-ollama>=1.0.1",
    "nest-asyncio>=1.6.0",
    "prompt-toolkit>=3.0.52",
    "python-dotenv>=1.2.1",
    "rich>=14.2.0",
    "torch>=2.9.1",
    "unstructured[md]>=0.18.21",
    "watchdog>=6.0.0",
]