fix: colors and other cleanup

2025-12-30 06:28:40 +03:00
parent 19d4fe09b6
commit 592e393f06
3 changed files with 85 additions and 105 deletions


@@ -1,6 +1,7 @@
-MD_FOLDER=my_docs
+MD_FOLDER=notes
 EMBEDDING_MODEL=mxbai-embed-large:latest
 LLM_MODEL=qwen2.5:7b-instruct-q8_0
+OLLAMA_BASE_URL=http://localhost:11434
 SYSTEM_PROMPT="You are a precise technical assistant. Cite sources using [filename]. Be concise."

.gitignore

@@ -12,4 +12,4 @@ wheels/
 #
 .env
 .cache/
-my_docs/
+notes/

main.py

@@ -7,16 +7,13 @@ import asyncio
 import re
 from pathlib import Path
 from collections import deque
-from typing import List, Dict, Tuple
-import torch
+from typing import List, Dict
 from dotenv import load_dotenv
 from rich.console import Console
 from rich.panel import Panel
-from rich.markdown import Markdown
 from prompt_toolkit import PromptSession
 from prompt_toolkit.styles import Style
-from prompt_toolkit.patch_stdout import patch_stdout
 from langchain_community.document_loaders import UnstructuredMarkdownLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -32,13 +29,14 @@ from watchdog.events import FileSystemEventHandler
 # =========================
 # CONFIG
 # =========================
-console = Console()
+console = Console(color_system="standard", force_terminal=True)
 session = PromptSession()
 load_dotenv()
 style = Style.from_dict({"prompt": "bold #6a0dad"})
-# --- PROMPTS ---
+OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
 SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
 SYSTEM_PROMPT_ANALYSIS = (
     "You are an expert tutor and progress evaluator. "
@@ -51,7 +49,7 @@ SYSTEM_PROMPT_ANALYSIS = (
 USER_PROMPT_TEMPLATE = os.getenv("USER_PROMPT_TEMPLATE",
     "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
-MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
+MD_DIRECTORY = os.getenv("MD_FOLDER", "./my_docs")
 EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
 LLM_MODEL = os.getenv("LLM_MODEL", "llama3")
@@ -64,28 +62,11 @@ CHUNK_OVERLAP = 200
 TOP_K = 6
 COLLECTION_NAME = "md_rag"
-# Limit context size for Analysis mode (approx 24k chars ~ 6k tokens) to prevent OOM
 MAX_ANALYSIS_CONTEXT_CHARS = 24000
 BATCH_SIZE = 10
 MAX_PARALLEL_FILES = 3
-# =========================
-# GPU SETUP
-# =========================
-def setup_gpu():
-    if torch.cuda.is_available():
-        torch.cuda.set_per_process_memory_fraction(0.95)
-        device_id = torch.cuda.current_device()
-        device_name = torch.cuda.get_device_name(device_id)
-        total_vram = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
-        console.print(f"[green]✓ GPU: {device_name} ({total_vram:.1f}GB)[/green]\n")
-    else:
-        console.print("[yellow]⚠ CPU mode[/yellow]\n")
-    console.print("\n")
-setup_gpu()
 # =========================
 # UTILS & CACHE
 # =========================
@@ -105,14 +86,12 @@ def save_hash_cache(cache: dict):
 # ROUTING LOGIC
 # =========================
 def classify_intent(query: str) -> str:
-    """
-    Determines if the user wants a specific search (RAG) or a global assessment.
-    """
     analysis_keywords = [
         r"assess my progress", r"eval(uate)? my (learning|knowledge)",
         r"what have i learned", r"summary of (my )?notes",
         r"my progress", r"learning path", r"knowledge gap",
-        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний"
+        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний",
+        r"сегодня урок", r"что я изучил"
     ]
     query_lower = query.lower()
@@ -137,7 +116,6 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str
             current += sentence
         else:
             if current: chunks.append(current.strip())
-            # Handle extremely long sentences by word splitting
             if len(sentence) > max_chars:
                 words = sentence.split()
                 temp = ""
@@ -164,7 +142,7 @@ class ChunkProcessor:
         try:
             docs = await asyncio.to_thread(UnstructuredMarkdownLoader(file_path).load)
         except Exception as e:
-            console.print(f"[red]{Path(file_path).name}: {e}[/red]")
+            console.print(f"{Path(file_path).name}: {e}", style="red")
             return []
         splitter = RecursiveCharacterTextSplitter(
@@ -193,7 +171,7 @@ class ChunkProcessor:
             await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids)
             return True
         except Exception as e:
-            console.print(f"[red]✗ Embed error: {e}[/red]")
+            console.print(f"✗ Embed error: {e}", style="red")
             return False
     async def index_file(self, file_path: str, cache: dict) -> bool:
@@ -208,14 +186,14 @@ class ChunkProcessor:
         try:
             self.vectorstore._collection.delete(where={"source": file_path})
         except:
-            pass # Collection might be empty
+            pass
         for i in range(0, len(chunks), BATCH_SIZE):
             batch = chunks[i:i + BATCH_SIZE]
             await self.embed_batch(batch)
         cache[file_path] = current_hash
-        console.print(f"[green]{Path(file_path).name} ({len(chunks)} chunks)[/green]")
+        console.print(f"{Path(file_path).name} ({len(chunks)} chunks)", style="green")
         return True
 # =========================
@@ -269,7 +247,11 @@ class ConversationMemory:
         return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])
 def get_chain(system_prompt):
-    llm = ChatOllama(model=LLM_MODEL, temperature=0.2)
+    llm = ChatOllama(
+        model=LLM_MODEL,
+        temperature=0.2,
+        base_url=OLLAMA_BASE_URL
+    )
     prompt = ChatPromptTemplate.from_messages([
         ("system", system_prompt),
         ("human", USER_PROMPT_TEMPLATE)
@@ -291,7 +273,10 @@ async def main():
         border_style="cyan"
     ))
-    embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
+    embeddings = OllamaEmbeddings(
+        model=EMBEDDING_MODEL,
+        base_url=OLLAMA_BASE_URL
+    )
     vectorstore = Chroma(
         collection_name=COLLECTION_NAME,
         persist_directory=CHROMA_PATH,
@@ -301,14 +286,13 @@ async def main():
     processor = ChunkProcessor(vectorstore)
     cache = load_hash_cache()
-    console.print("\n[yellow]Checking documents...[/yellow]")
+    console.print("Checking documents...", style="yellow")
     files = [
         os.path.join(root, file)
         for root, _, files in os.walk(MD_DIRECTORY)
         for file in files if file.endswith(".md")
     ]
-    # Initial Indexing
    semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
     async def sem_task(fp):
         async with semaphore:
@@ -322,77 +306,72 @@ async def main():
     observer = start_watcher(processor, cache)
     memory = ConversationMemory()
-    console.print("[bold green]💬 Ready! Type 'exit' to quit.[/bold green]\n")
+    console.print("💬 Ready! Type 'exit' to quit.", style="bold green")
     try:
-        with patch_stdout():
-            while True:
-                query = await session.prompt_async("> ", style=style)
-                query = query.strip()
-                if query.lower() in {"exit", "quit", "q"}:
-                    print("Goodbye!")
-                    break
-                if not query: continue
-                mode = classify_intent(query)
-                history_str = memory.get_history()
-                if mode == "SEARCH":
-                    console.print("[bold blue]🔍 SEARCH MODE (Top-K)[/bold blue]")
-                    # Standard RAG
-                    retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
-                    docs = await asyncio.to_thread(retriever.invoke, query)
-                    context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
-                    chain = get_chain(SYSTEM_PROMPT_SEARCH)
-                else: # ANALYSIS MODE
-                    console.print("[bold magenta]📊 ANALYSIS MODE (Full Context)[/bold magenta]")
-                    # Fetch ALL documents (limited by size)
-                    # Chroma .get() returns dict with keys: ids, embeddings, documents, metadatas
-                    db_data = await asyncio.to_thread(vectorstore.get)
-                    all_texts = db_data['documents']
-                    all_metas = db_data['metadatas']
-                    if not all_texts:
-                        console.print("[red]No documents found to analyze![/red]")
-                        continue
-                    # Concatenate content for analysis
-                    full_context = ""
-                    char_count = 0
-                    # Sort arbitrarily or by source to group files
-                    paired = sorted(zip(all_texts, all_metas), key=lambda x: x[1]['source'])
-                    for text, meta in paired:
-                        entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
-                        if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
-                            full_context += "\n[...Truncated due to context limit...]"
-                            console.print("[yellow]⚠ Context limit reached, truncating analysis data.[/yellow]")
-                            break
-                        full_context += entry
-                        char_count += len(entry)
-                    context_str = full_context
-                    chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
-                response = ""
-                console.print(f"[dim]Context size: {len(context_str)} chars[/dim]")
-                async for chunk in chain.astream({
-                    "context": context_str,
-                    "question": query,
-                    "history": history_str
-                }):
-                    print(chunk, end="")
-                    response += chunk
-                console.print("\n")
-                memory.add("user", query)
-                memory.add("assistant", response)
+        while True:
+            query = await session.prompt_async("> ", style=style)
+            query = query.strip()
+            if query.lower() in {"exit", "quit", "q"}:
+                console.print("Goodbye!", style="yellow")
+                break
+            if not query: continue
+            mode = classify_intent(query)
+            history_str = memory.get_history()
+            if mode == "SEARCH":
+                console.print("🔍 SEARCH MODE (Top-K)", style="bold blue")
+                retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
+                docs = await asyncio.to_thread(retriever.invoke, query)
+                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+                chain = get_chain(SYSTEM_PROMPT_SEARCH)
+            else: # ANALYSIS MODE
+                console.print("📊 ANALYSIS MODE (Full Context)", style="bold magenta")
+                db_data = await asyncio.to_thread(vectorstore.get)
+                all_texts = db_data['documents']
+                all_metas = db_data['metadatas']
+                if not all_texts:
+                    console.print("No documents found to analyze!", style="red")
+                    continue
+                full_context = ""
+                char_count = 0
+                paired = sorted(zip(all_texts, all_metas), key=lambda x: x[1]['source'])
+                for text, meta in paired:
+                    entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
+                    if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
+                        full_context += "\n[...Truncated due to context limit...]"
+                        console.print("⚠ Context limit reached, truncating analysis data.", style="yellow")
+                        break
+                    full_context += entry
+                    char_count += len(entry)
+                context_str = full_context
+                chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
+            response = ""
+            console.print(f"Context size: {len(context_str)} chars", style="dim")
+            console.print("Assistant:", style="blue", end=" ")
+            async for chunk in chain.astream({
+                "context": context_str,
+                "question": query,
+                "history": history_str
+            }):
+                print(chunk, end="")
+                response += chunk
+            console.print("\n")
+            memory.add("user", query)
+            memory.add("assistant", response)
     finally:
         observer.stop()
@@ -406,5 +385,5 @@ if __name__ == "__main__":
         loop = asyncio.get_event_loop()
         loop.run_until_complete(main())
     except KeyboardInterrupt:
-        console.print("\n[yellow]Goodbye![/yellow]")
+        console.print("Goodbye!", style="yellow")
         sys.exit(0)
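For reference, a minimal standalone sketch of the Rich console pattern this commit moves to: a standard 16-color palette with output forced even when stdout is not a TTY, and styles passed through the style= keyword instead of inline [green]...[/green] markup. The messages are taken from the diff above; the snippet itself is only illustrative:

    from rich.console import Console

    # Standard 16-color ANSI palette; force ANSI output even when stdout is piped.
    console = Console(color_system="standard", force_terminal=True)

    # Style via the style= keyword rather than inline markup tags.
    console.print("Checking documents...", style="yellow")
    console.print("💬 Ready! Type 'exit' to quit.", style="bold green")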