Initial commit

.env.example (Normal file, 3 lines)
@@ -0,0 +1,3 @@
MD_FOLDER=my_docs
EMBEDDING_MODEL=mxbai-embed-large:latest
LLM_MODEL=qwen2.5:7b-instruct-q8_0

.gitignore (Normal file, vendored, 15 lines)
@@ -0,0 +1,15 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# Local files
.env
.cache/
my_docs/

.python-version (Normal file, 1 line)
@@ -0,0 +1 @@
3.13

README.md (Normal file, 22 lines)
@@ -0,0 +1,22 @@
# Local RAG System for Markdown Files

## Requirements

1. Install [Ollama](https://ollama.com/download/)
2. Pull the required models:

```bash
ollama pull qwen2.5:7b-instruct-q8_0
ollama pull mxbai-embed-large:latest
```

You can use any models, but update the model names in `.env` to match.
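
For example (the model names below are only illustrations, not part of this setup; any chat and embedding models available in Ollama will work):

```bash
ollama pull llama3.1:8b
ollama pull nomic-embed-text
```

Then set `LLM_MODEL=llama3.1:8b` and `EMBEDDING_MODEL=nomic-embed-text` in `.env`.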

## Run

```bash
cp .env.example .env
uv sync
uv run main.py
```

main.py (Normal file, 368 lines)
@@ -0,0 +1,368 @@
#!/usr/bin/env python3
import os
import sys
import json
import hashlib
import asyncio
from pathlib import Path
from collections import deque
from typing import List, Dict

import torch
from dotenv import load_dotenv
from rich.console import Console
from rich.panel import Panel
from prompt_toolkit import PromptSession
from prompt_toolkit.styles import Style
from prompt_toolkit.patch_stdout import patch_stdout

from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaEmbeddings, ChatOllama
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler

# =========================
# CONFIG
# =========================
console = Console()
session = PromptSession()
load_dotenv()

style = Style.from_dict({"prompt": "bold #6a0dad"})

MD_DIRECTORY = os.getenv("MD_FOLDER")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
LLM_MODEL = os.getenv("LLM_MODEL")

CHROMA_PATH = "./.cache/chroma_db"
HASH_CACHE = "./.cache/file_hashes.json"

MAX_EMBED_CHARS = 380
CHUNK_SIZE = 1200
CHUNK_OVERLAP = 200
TOP_K = 6
COLLECTION_NAME = "md_rag"

BATCH_SIZE = 10
MAX_PARALLEL_FILES = 3

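# Tuning notes (based on how these constants are used below):
# - MAX_EMBED_CHARS caps the characters per chunk handed to the embedder;
#   validate_chunk_size() re-splits anything longer.
# - CHUNK_SIZE / CHUNK_OVERLAP configure the RecursiveCharacterTextSplitter.
# - TOP_K is the number of chunks retrieved per question.
# - BATCH_SIZE is the number of documents per add_documents() call.
# - MAX_PARALLEL_FILES bounds how many files are indexed concurrently.
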
# =========================
# GPU SETUP
# =========================
def setup_gpu():
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.95)
    else:
        console.print("[yellow]⚠ CPU mode[/yellow]")

setup_gpu()

# =========================
# HASH CACHE
# =========================
def get_file_hash(file_path: str) -> str:
    return hashlib.md5(Path(file_path).read_bytes()).hexdigest()

def load_hash_cache() -> dict:
    Path(HASH_CACHE).parent.mkdir(parents=True, exist_ok=True)
    if Path(HASH_CACHE).exists():
        return json.loads(Path(HASH_CACHE).read_text())
    return {}

def save_hash_cache(cache: dict):
    Path(HASH_CACHE).write_text(json.dumps(cache, indent=2))

# =========================
# CHUNK VALIDATION
# =========================
def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str]:
    # Re-split any chunk longer than max_chars: first at sentence boundaries,
    # then word by word for sentences that are themselves too long.
    if len(text) <= max_chars:
        return [text]

    sentences = text.replace('. ', '.|').replace('! ', '!|').replace('? ', '?|').split('|')
    chunks = []
    current = ""

    for sentence in sentences:
        if len(current) + len(sentence) <= max_chars:
            current += sentence
        else:
            if current:
                chunks.append(current.strip())
            if len(sentence) > max_chars:
                words = sentence.split()
                temp = ""
                for word in words:
                    if len(temp) + len(word) + 1 <= max_chars:
                        temp += word + " "
                    else:
                        if temp:
                            chunks.append(temp.strip())
                        temp = word + " "
                if temp:
                    chunks.append(temp.strip())
                current = ""  # already flushed above; reset so it is not appended twice
            else:
                current = sentence

    if current:
        chunks.append(current.strip())

    return [c for c in chunks if c]

# =========================
# DOCUMENT PROCESSING
# =========================
class ChunkProcessor:
    def __init__(self, vectorstore):
        self.vectorstore = vectorstore
        self.semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)

    async def process_file(self, file_path: str) -> List[Dict]:
        try:
            docs = await asyncio.to_thread(
                UnstructuredMarkdownLoader(file_path).load
            )
        except Exception as e:
            console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]")
            return []

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE,
            chunk_overlap=CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", " "]
        )

        chunks = []
        for doc_idx, doc in enumerate(docs):
            for chunk_idx, text in enumerate(splitter.split_text(doc.page_content)):
                safe_texts = validate_chunk_size(text)
                for sub_idx, safe_text in enumerate(safe_texts):
                    chunks.append({
                        "id": f"{file_path}::{doc_idx}::{chunk_idx}::{sub_idx}",
                        "text": safe_text,
                        "metadata": {"source": file_path, **doc.metadata}
                    })
        return chunks

    async def embed_batch(self, batch: List[Dict]) -> bool:
        if not batch:
            return True

        try:
            docs = [Document(page_content=c["text"], metadata=c["metadata"])
                    for c in batch]
            ids = [c["id"] for c in batch]

            await asyncio.to_thread(
                self.vectorstore.add_documents,
                docs,
                ids=ids
            )
            return True

        except Exception as e:
            error_msg = str(e).lower()
            if "context length" in error_msg or "input length" in error_msg:
                console.print("[yellow]⚠ Oversized chunk detected, processing individually[/yellow]")
                for item in batch:
                    try:
                        doc = Document(page_content=item["text"], metadata=item["metadata"])
                        await asyncio.to_thread(
                            self.vectorstore.add_documents,
                            [doc],
                            ids=[item["id"]]
                        )
                    except Exception:
                        console.print(f"[red]✗ Skipping chunk (too large): {len(item['text'])} chars[/red]")
                        continue
                return True
            else:
                console.print(f"[red]✗ Embed error: {e}[/red]")
                return False

    async def index_file(self, file_path: str, cache: dict) -> bool:
        async with self.semaphore:
            current_hash = get_file_hash(file_path)
            if cache.get(file_path) == current_hash:
                return False

            chunks = await self.process_file(file_path)
            if not chunks:
                return False

            for i in range(0, len(chunks), BATCH_SIZE):
                batch = chunks[i:i + BATCH_SIZE]
                success = await self.embed_batch(batch)
                if not success:
                    console.print(f"[yellow]⚠ Partial failure in {Path(file_path).name}[/yellow]")

            cache[file_path] = current_hash
            console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
            return True

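# Per-file indexing pipeline (ChunkProcessor): skip unchanged files via the MD5
# cache, load the markdown, split into ~CHUNK_SIZE chunks, re-split anything over
# MAX_EMBED_CHARS, then embed in BATCH_SIZE batches with a single-document
# fallback when the embedder rejects a batch for length.
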
# =========================
# FILE WATCHER
# =========================
class DocumentWatcher(FileSystemEventHandler):
    def __init__(self, processor, cache):
        self.processor = processor
        self.cache = cache
        self.queue = deque()
        self.processing = False

    def on_modified(self, event):
        if not event.is_directory and event.src_path.endswith(".md"):
            self.queue.append(event.src_path)

    async def process_queue(self):
        while True:
            if self.queue and not self.processing:
                self.processing = True
                file_path = self.queue.popleft()
                if Path(file_path).exists():
                    await self.processor.index_file(file_path, self.cache)
                    save_hash_cache(self.cache)
                self.processing = False
            await asyncio.sleep(1)

def start_watcher(processor, cache):
    handler = DocumentWatcher(processor, cache)
    observer = Observer()
    observer.schedule(handler, MD_DIRECTORY, recursive=True)
    observer.start()
    asyncio.create_task(handler.process_queue())
    return observer

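# The watchdog callback only enqueues modified .md paths; process_queue() drains
# the queue about once per second and re-indexes only files whose hash changed.
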
# =========================
# RAG CHAIN & MEMORY
# =========================
class ConversationMemory:
    def __init__(self, max_messages: int = 8):
        self.messages = []
        self.max_messages = max_messages

    def add(self, role: str, content: str):
        self.messages.append({"role": role, "content": content})
        if len(self.messages) > self.max_messages:
            self.messages.pop(0)

    def get_history(self) -> str:
        if not self.messages:
            return "No previous conversation."
        return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])

def get_rag_components(retriever):
    llm = ChatOllama(model=LLM_MODEL, temperature=0.1)

    # FIX 1: Added {history} to the prompt
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a precise technical assistant. Cite sources using [filename]. Be concise."),
        ("human", "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
    ])

    return prompt | llm | StrOutputParser()

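# Note: the retriever argument is not used inside get_rag_components(); retrieval
# happens in the main loop and the formatted documents are passed in as {context}.
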
# =========================
# MAIN
# =========================
async def main():
    Path(MD_DIRECTORY).mkdir(parents=True, exist_ok=True)
    Path(CHROMA_PATH).parent.mkdir(parents=True, exist_ok=True)

    console.print(Panel.fit(
        f"[bold cyan]⚡ RAG System[/bold cyan]\n"
        f"📂 Docs: {MD_DIRECTORY}\n"
        f"🧠 Embed: {EMBEDDING_MODEL}\n"
        f"🤖 LLM: {LLM_MODEL}",
        border_style="cyan"
    ))

    embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
    vectorstore = Chroma(
        collection_name=COLLECTION_NAME,
        persist_directory=CHROMA_PATH,
        embedding_function=embeddings
    )

    processor = ChunkProcessor(vectorstore)
    cache = load_hash_cache()

    console.print("\n[yellow]Indexing documents...[/yellow]")
    files = [
        os.path.join(root, file)
        for root, _, files in os.walk(MD_DIRECTORY)
        for file in files if file.endswith(".md")
    ]

    semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
    async def sem_task(fp):
        async with semaphore:
            return await processor.index_file(fp, cache)

    tasks = [sem_task(fp) for fp in files]
    for fut in asyncio.as_completed(tasks):
        await fut
    save_hash_cache(cache)

    console.print(f"[green]✓ Processed {len(files)} files[/green]\n")

    observer = start_watcher(processor, cache)

    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": TOP_K}
    )

    rag_chain = get_rag_components(retriever)
    memory = ConversationMemory()

    console.print("[bold green]💬 Ready![/bold green]\n")

    try:
        with patch_stdout():
            while True:
                query = await session.prompt_async("> ", style=style)
                query = query.strip()
                if query.lower() in {"exit", "quit", "q"}:
                    print("Goodbye!")
                    break
                if not query:
                    continue

                docs = await asyncio.to_thread(retriever.invoke, query)
                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
                history_str = memory.get_history()

                response = ""
                async for chunk in rag_chain.astream({
                    "context": context_str,
                    "question": query,
                    "history": history_str
                }):
                    print(chunk, end="")
                    response += chunk
                console.print("\n")

                memory.add("user", query)
                memory.add("assistant", response)

    finally:
        observer.stop()
        observer.join()

if __name__ == "__main__":
    import nest_asyncio
    nest_asyncio.apply()
    try:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(main())
    except KeyboardInterrupt:
        console.print("\n[yellow]Goodbye![/yellow]")
        sys.exit(0)
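
A note on per-run overrides: `load_dotenv()` does not overwrite variables that are already set in the shell (python-dotenv's default behavior), so a one-off run against a different folder can be done from the command line; the path below is only an example:

```bash
MD_FOLDER=./notes uv run main.py
```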

pyproject.toml (Normal file, 19 lines)
@@ -0,0 +1,19 @@
[project]
name = "rag-llm"
version = "0.1.0"
description = "Local RAG system for Markdown files using Ollama and Chroma"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "chromadb>=1.4.0",
    "langchain-chroma>=1.1.0",
    "langchain-community>=0.4.1",
    "langchain-ollama>=1.0.1",
    "nest-asyncio>=1.6.0",
    "prompt-toolkit>=3.0.52",
    "python-dotenv>=1.2.1",
    "rich>=14.2.0",
    "torch>=2.9.1",
    "unstructured[md]>=0.18.21",
    "watchdog>=6.0.0",
]