fix: colors and other
@@ -1,6 +1,7 @@
-MD_FOLDER=my_docs
+MD_FOLDER=notes
 EMBEDDING_MODEL=mxbai-embed-large:latest
 LLM_MODEL=qwen2.5:7b-instruct-q8_0
+OLLAMA_BASE_URL=http://localhost:11434

 SYSTEM_PROMPT="You are a precise technical assistant. Cite sources using [filename]. Be concise."
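For orientation, these variables are read in main.py through python-dotenv. A minimal sketch of the lookup pattern, using the same defaults that appear later in this diff:

import os
from dotenv import load_dotenv

load_dotenv()  # read .env from the working directory

MD_DIRECTORY = os.getenv("MD_FOLDER", "./my_docs")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
LLM_MODEL = os.getenv("LLM_MODEL", "llama3")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
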
.gitignore (vendored): 2 lines changed
@@ -12,4 +12,4 @@ wheels/
 #
 .env
 .cache/
-my_docs/
+notes/
main.py: 165 lines changed
@@ -7,16 +7,13 @@ import asyncio
 import re
 from pathlib import Path
 from collections import deque
-from typing import List, Dict, Tuple
+from typing import List, Dict

 import torch
 from dotenv import load_dotenv
 from rich.console import Console
 from rich.panel import Panel
 from rich.markdown import Markdown
 from prompt_toolkit import PromptSession
 from prompt_toolkit.styles import Style
 from prompt_toolkit.patch_stdout import patch_stdout

 from langchain_community.document_loaders import UnstructuredMarkdownLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -32,13 +29,14 @@ from watchdog.events import FileSystemEventHandler
 # =========================
 # CONFIG
 # =========================
-console = Console()
+console = Console(color_system="standard", force_terminal=True)
 session = PromptSession()
 load_dotenv()

 style = Style.from_dict({"prompt": "bold #6a0dad"})

 # --- PROMPTS ---
+OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

 SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
 SYSTEM_PROMPT_ANALYSIS = (
     "You are an expert tutor and progress evaluator. "
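The Console change above is the heart of the "colors" fix: inline markup tags are replaced by an explicit style= argument, and color_system="standard" with force_terminal=True pins Rich to basic ANSI colors even when the terminal is not auto-detected. A rough sketch of the two equivalent call styles (the message text here is illustrative, not from the codebase):

from rich.console import Console

console = Console(color_system="standard", force_terminal=True)

# Inline markup (the pattern used throughout the previous version):
console.print("[green]✓ indexed[/green]")

# Explicit style argument (the pattern this commit switches to):
console.print("✓ indexed", style="green")
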
@@ -51,7 +49,7 @@ SYSTEM_PROMPT_ANALYSIS = (
 USER_PROMPT_TEMPLATE = os.getenv("USER_PROMPT_TEMPLATE",
     "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")

-MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
+MD_DIRECTORY = os.getenv("MD_FOLDER", "./my_docs")
 EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
 LLM_MODEL = os.getenv("LLM_MODEL", "llama3")
@@ -64,28 +62,11 @@ CHUNK_OVERLAP = 200
 TOP_K = 6
 COLLECTION_NAME = "md_rag"

 # Limit context size for Analysis mode (approx 24k chars ~ 6k tokens) to prevent OOM
 MAX_ANALYSIS_CONTEXT_CHARS = 24000

 BATCH_SIZE = 10
 MAX_PARALLEL_FILES = 3
-
-# =========================
-# GPU SETUP
-# =========================
-def setup_gpu():
-    if torch.cuda.is_available():
-        torch.cuda.set_per_process_memory_fraction(0.95)
-        device_id = torch.cuda.current_device()
-        device_name = torch.cuda.get_device_name(device_id)
-        total_vram = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
-        console.print(f"[green]✓ GPU: {device_name} ({total_vram:.1f}GB)[/green]\n")
-    else:
-        console.print("[yellow]⚠ CPU mode[/yellow]\n")
-    console.print("\n")
-
-setup_gpu()
-
 # =========================
 # UTILS & CACHE
 # =========================
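The MAX_ANALYSIS_CONTEXT_CHARS comment above rests on the common rule of thumb of roughly 4 characters per token for English text, so 24000 / 4 ≈ 6000 tokens; the real ratio depends on the tokenizer and on the language of the notes. A tiny helper in that spirit (approx_tokens is a hypothetical name, not part of this codebase):

def approx_tokens(text: str, chars_per_token: float = 4.0) -> int:
    """Very rough token estimate: character count divided by an assumed chars-per-token ratio."""
    return int(len(text) / chars_per_token)

# 24,000 characters comes out to about 6,000 tokens under this assumption.
assert approx_tokens("x" * 24_000) == 6_000
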
@@ -105,14 +86,12 @@ def save_hash_cache(cache: dict):
 # ROUTING LOGIC
 # =========================
 def classify_intent(query: str) -> str:
     """
     Determines if the user wants a specific search (RAG) or a global assessment.
     """
     analysis_keywords = [
         r"assess my progress", r"eval(uate)? my (learning|knowledge)",
         r"what have i learned", r"summary of (my )?notes",
         r"my progress", r"learning path", r"knowledge gap",
-        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний"
+        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний",
+        r"сегодня урок", r"что я изучил"
     ]

     query_lower = query.lower()
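The rest of classify_intent is outside this hunk. Given the analysis_keywords list and the "SEARCH" branch checked later in main(), a plausible completion looks like the sketch below; the "ANALYSIS" label is inferred, since main() only tests for "SEARCH":

import re

def classify_intent(query: str) -> str:
    """Route a query to keyword-triggered analysis mode, otherwise to standard RAG search."""
    analysis_keywords = [
        r"assess my progress", r"my progress", r"knowledge gap",
        # ... full pattern list as in the diff above, including the Russian phrases ...
    ]
    query_lower = query.lower()
    for pattern in analysis_keywords:
        if re.search(pattern, query_lower):
            return "ANALYSIS"  # assumed label; main() only checks for "SEARCH"
    return "SEARCH"
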
@@ -137,7 +116,6 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str
             current += sentence
         else:
             if current: chunks.append(current.strip())
             # Handle extremely long sentences by word splitting
             if len(sentence) > max_chars:
                 words = sentence.split()
                 temp = ""
@@ -164,7 +142,7 @@ class ChunkProcessor:
         try:
             docs = await asyncio.to_thread(UnstructuredMarkdownLoader(file_path).load)
         except Exception as e:
-            console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]")
+            console.print(f"✗ {Path(file_path).name}: {e}", style="red")
             return []

         splitter = RecursiveCharacterTextSplitter(
@@ -193,7 +171,7 @@ class ChunkProcessor:
             await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids)
             return True
         except Exception as e:
-            console.print(f"[red]✗ Embed error: {e}[/red]")
+            console.print(f"✗ Embed error: {e}", style="red")
             return False

     async def index_file(self, file_path: str, cache: dict) -> bool:
@@ -208,14 +186,14 @@ class ChunkProcessor:
         try:
             self.vectorstore._collection.delete(where={"source": file_path})
         except:
-            pass # Collection might be empty
+            pass

         for i in range(0, len(chunks), BATCH_SIZE):
             batch = chunks[i:i + BATCH_SIZE]
             await self.embed_batch(batch)

         cache[file_path] = current_hash
-        console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
+        console.print(f"✓ {Path(file_path).name} ({len(chunks)} chunks)", style="green")
         return True

 # =========================
@@ -269,7 +247,11 @@ class ConversationMemory:
         return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])

 def get_chain(system_prompt):
-    llm = ChatOllama(model=LLM_MODEL, temperature=0.2)
+    llm = ChatOllama(
+        model=LLM_MODEL,
+        temperature=0.2,
+        base_url=OLLAMA_BASE_URL
+    )
     prompt = ChatPromptTemplate.from_messages([
         ("system", system_prompt),
         ("human", USER_PROMPT_TEMPLATE)
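The tail of get_chain falls outside this hunk. Since main() later streams plain string chunks from chain.astream(...), the function presumably finishes by piping the prompt into the model and a string parser, roughly as sketched here (the StrOutputParser step is an assumption, not shown in the diff; the other names come from the module above):

from langchain_core.output_parsers import StrOutputParser

def get_chain(system_prompt):
    llm = ChatOllama(model=LLM_MODEL, temperature=0.2, base_url=OLLAMA_BASE_URL)
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", USER_PROMPT_TEMPLATE),
    ])
    # LCEL pipeline: fill the template, call the model, emit plain strings.
    return prompt | llm | StrOutputParser()
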
@@ -291,7 +273,10 @@ async def main():
         border_style="cyan"
     ))

-    embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
+    embeddings = OllamaEmbeddings(
+        model=EMBEDDING_MODEL,
+        base_url=OLLAMA_BASE_URL
+    )
     vectorstore = Chroma(
         collection_name=COLLECTION_NAME,
         persist_directory=CHROMA_PATH,
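With both ChatOllama and OllamaEmbeddings now taking base_url, pointing the whole app at a remote Ollama host is only an .env change. A quick way to confirm the wiring, run after the Chroma setup above (the host and query below are hypothetical; similarity_search is the standard LangChain Chroma call):

# .env (illustrative remote host)
# OLLAMA_BASE_URL=http://192.168.1.50:11434

# Smoke test: embed a query and pull the closest chunks from the persisted store.
hits = vectorstore.similarity_search("what is chunk overlap?", k=3)
for doc in hits:
    print(doc.metadata["source"], len(doc.page_content))
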
@@ -301,14 +286,13 @@ async def main():
     processor = ChunkProcessor(vectorstore)
     cache = load_hash_cache()

-    console.print("\n[yellow]Checking documents...[/yellow]")
+    console.print("Checking documents...", style="yellow")
     files = [
         os.path.join(root, file)
         for root, _, files in os.walk(MD_DIRECTORY)
         for file in files if file.endswith(".md")
     ]

     # Initial Indexing
     semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
     async def sem_task(fp):
         async with semaphore:
@@ -322,77 +306,72 @@ async def main():
     observer = start_watcher(processor, cache)
     memory = ConversationMemory()

-    console.print("[bold green]💬 Ready! Type 'exit' to quit.[/bold green]\n")
+    console.print("💬 Ready! Type 'exit' to quit.", style="bold green")

     try:
         with patch_stdout():
             while True:
                 query = await session.prompt_async("> ", style=style)
                 query = query.strip()
                 if query.lower() in {"exit", "quit", "q"}:
-                    print("Goodbye!")
+                    console.print("Goodbye!", style="yellow")
                     break
                 if not query: continue

                 mode = classify_intent(query)
                 history_str = memory.get_history()

                 if mode == "SEARCH":
-                    console.print("[bold blue]🔍 SEARCH MODE (Top-K)[/bold blue]")
+                    console.print("🔍 SEARCH MODE (Top-K)", style="bold blue")

                     # Standard RAG
                     retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
                     docs = await asyncio.to_thread(retriever.invoke, query)
                     context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)

                     chain = get_chain(SYSTEM_PROMPT_SEARCH)

                 else:  # ANALYSIS MODE
-                    console.print("[bold magenta]📊 ANALYSIS MODE (Full Context)[/bold magenta]")
+                    console.print("📊 ANALYSIS MODE (Full Context)", style="bold magenta")

                     # Fetch ALL documents (limited by size)
                     # Chroma .get() returns dict with keys: ids, embeddings, documents, metadatas
                     db_data = await asyncio.to_thread(vectorstore.get)
                     all_texts = db_data['documents']
                     all_metas = db_data['metadatas']

                     if not all_texts:
-                        console.print("[red]No documents found to analyze![/red]")
+                        console.print("No documents found to analyze!", style="red")
                         continue

                     # Concatenate content for analysis
                     full_context = ""
                     char_count = 0

                     # Sort arbitrarily or by source to group files
                     paired = sorted(zip(all_texts, all_metas), key=lambda x: x[1]['source'])

                     for text, meta in paired:
                         entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
                         if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
                             full_context += "\n[...Truncated due to context limit...]"
-                            console.print("[yellow]⚠ Context limit reached, truncating analysis data.[/yellow]")
+                            console.print("⚠ Context limit reached, truncating analysis data.", style="yellow")
                             break
                         full_context += entry
                         char_count += len(entry)

                     context_str = full_context
                     chain = get_chain(SYSTEM_PROMPT_ANALYSIS)

                 response = ""
-                console.print(f"[dim]Context size: {len(context_str)} chars[/dim]")
+                console.print(f"Context size: {len(context_str)} chars", style="dim")
+                console.print("Assistant:", style="blue", end=" ")

                 async for chunk in chain.astream({
                     "context": context_str,
                     "question": query,
                     "history": history_str
                 }):
                     print(chunk, end="")
                     response += chunk
                 console.print("\n")

                 memory.add("user", query)
                 memory.add("assistant", response)

     finally:
         observer.stop()
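A small practical note on the streaming loop above: print(chunk, end="") emits partial lines, and on a line-buffered terminal those may not appear until a newline arrives. If tokens seem to stall, one possible tweak (same loop, same inputs, just an explicit flush) is:

async for chunk in chain.astream({"context": context_str, "question": query, "history": history_str}):
    print(chunk, end="", flush=True)  # flush so each partial token shows up immediately
    response += chunk
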
@@ -406,5 +385,5 @@ if __name__ == "__main__":
         loop = asyncio.get_event_loop()
         loop.run_until_complete(main())
     except KeyboardInterrupt:
-        console.print("\n[yellow]Goodbye![/yellow]")
+        console.print("Goodbye!", style="yellow")
         sys.exit(0)
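One side note on the runner itself, outside the scope of this commit: asyncio.get_event_loop() followed by run_until_complete() raises deprecation warnings on recent Python versions when no loop is already running. A sketch of the more current idiom, keeping the same KeyboardInterrupt handling:

if __name__ == "__main__":
    try:
        asyncio.run(main())  # creates and closes the event loop for us
    except KeyboardInterrupt:
        console.print("Goodbye!", style="yellow")
        sys.exit(0)
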