diff --git a/.env.example b/.env.example
index af6f481..1159ced 100644
--- a/.env.example
+++ b/.env.example
@@ -1,6 +1,7 @@
-MD_FOLDER=my_docs
+MD_FOLDER=notes
 EMBEDDING_MODEL=mxbai-embed-large:latest
 LLM_MODEL=qwen2.5:7b-instruct-q8_0
+OLLAMA_BASE_URL=http://localhost:11434
 SYSTEM_PROMPT="You are a precise technical assistant. Cite sources using [filename]. Be concise."
diff --git a/.gitignore b/.gitignore
index 858124b..cc34fb3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,4 +12,4 @@
 wheels/
 # .env
 .cache/
-my_docs/
+notes/
diff --git a/main.py b/main.py
index ce314bf..cb67230 100644
--- a/main.py
+++ b/main.py
@@ -7,16 +7,13 @@ import asyncio
 import re
 from pathlib import Path
 from collections import deque
-from typing import List, Dict, Tuple
+from typing import List, Dict
 
-import torch
 from dotenv import load_dotenv
 from rich.console import Console
 from rich.panel import Panel
-from rich.markdown import Markdown
 from prompt_toolkit import PromptSession
 from prompt_toolkit.styles import Style
-from prompt_toolkit.patch_stdout import patch_stdout
 
 from langchain_community.document_loaders import UnstructuredMarkdownLoader
 from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -32,13 +29,14 @@ from watchdog.events import FileSystemEventHandler
 
 # =========================
 # CONFIG
 # =========================
-console = Console()
+console = Console(color_system="standard", force_terminal=True)
 session = PromptSession()
 load_dotenv()
 style = Style.from_dict({"prompt": "bold #6a0dad"})
 
-# --- PROMPTS ---
+OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
+
 SYSTEM_PROMPT_SEARCH = os.getenv("SYSTEM_PROMPT", "You are a precise technical assistant. Cite sources using [filename]. Be concise.")
 SYSTEM_PROMPT_ANALYSIS = (
     "You are an expert tutor and progress evaluator. "
@@ -51,7 +49,7 @@ SYSTEM_PROMPT_ANALYSIS = (
 
 USER_PROMPT_TEMPLATE = os.getenv("USER_PROMPT_TEMPLATE", "Previous Conversation:\n{history}\n\nContext from Docs:\n{context}\n\nCurrent Question: {question}")
 
-MD_DIRECTORY = os.getenv("MD_FOLDER", "./notes")
+MD_DIRECTORY = os.getenv("MD_FOLDER", "./my_docs")
 EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
 LLM_MODEL = os.getenv("LLM_MODEL", "llama3")
 
@@ -64,28 +62,11 @@ CHUNK_OVERLAP = 200
 
 TOP_K = 6
 COLLECTION_NAME = "md_rag"
 
-# Limit context size for Analysis mode (approx 24k chars ~ 6k tokens) to prevent OOM
 MAX_ANALYSIS_CONTEXT_CHARS = 24000
 BATCH_SIZE = 10
 MAX_PARALLEL_FILES = 3
 
-# =========================
-# GPU SETUP
-# =========================
-def setup_gpu():
-    if torch.cuda.is_available():
-        torch.cuda.set_per_process_memory_fraction(0.95)
-        device_id = torch.cuda.current_device()
-        device_name = torch.cuda.get_device_name(device_id)
-        total_vram = torch.cuda.get_device_properties(device_id).total_memory / (1024**3)
-        console.print(f"[green]✓ GPU: {device_name} ({total_vram:.1f}GB)[/green]\n")
-    else:
-        console.print("[yellow]⚠ CPU mode[/yellow]\n")
-    console.print("\n")
-
-setup_gpu()
-
 # =========================
 # UTILS & CACHE
 # =========================
@@ -105,14 +86,12 @@ def save_hash_cache(cache: dict):
 # =========================
 # ROUTING LOGIC
 # =========================
 def classify_intent(query: str) -> str:
-    """
-    Determines if the user wants a specific search (RAG) or a global assessment.
-    """
     analysis_keywords = [
         r"assess my progress", r"eval(uate)? my (learning|knowledge)", r"what have i learned",
         r"summary of (my )?notes", r"my progress", r"learning path", r"knowledge gap",
-        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний"
+        r"оцени (мой )?прогресс", r"что я выучил", r"итоги", r"анализ знаний",
+        r"сегодня урок", r"что я изучил"
     ]
 
     query_lower = query.lower()
@@ -137,7 +116,6 @@ def validate_chunk_size(text: str, max_chars: int = MAX_EMBED_CHARS) -> List[str
             current += sentence
         else:
             if current: chunks.append(current.strip())
-            # Handle extremely long sentences by word splitting
             if len(sentence) > max_chars:
                 words = sentence.split()
                 temp = ""
@@ -164,7 +142,7 @@ class ChunkProcessor:
         try:
             docs = await asyncio.to_thread(UnstructuredMarkdownLoader(file_path).load)
         except Exception as e:
-            console.print(f"[red]✗ {Path(file_path).name}: {e}[/red]")
+            console.print(f"✗ {Path(file_path).name}: {e}", style="red")
             return []
 
         splitter = RecursiveCharacterTextSplitter(
@@ -193,7 +171,7 @@ class ChunkProcessor:
             await asyncio.to_thread(self.vectorstore.add_documents, docs, ids=ids)
             return True
         except Exception as e:
-            console.print(f"[red]✗ Embed error: {e}[/red]")
+            console.print(f"✗ Embed error: {e}", style="red")
             return False
 
     async def index_file(self, file_path: str, cache: dict) -> bool:
@@ -208,14 +186,14 @@ class ChunkProcessor:
         try:
             self.vectorstore._collection.delete(where={"source": file_path})
         except:
-            pass # Collection might be empty
+            pass
 
         for i in range(0, len(chunks), BATCH_SIZE):
             batch = chunks[i:i + BATCH_SIZE]
             await self.embed_batch(batch)
 
         cache[file_path] = current_hash
-        console.print(f"[green]✓ {Path(file_path).name} ({len(chunks)} chunks)[/green]")
+        console.print(f"✓ {Path(file_path).name} ({len(chunks)} chunks)", style="green")
         return True
 
 # =========================
@@ -269,7 +247,11 @@ class ConversationMemory:
         return "\n".join([f"{m['role'].upper()}: {m['content']}" for m in self.messages])
 
 def get_chain(system_prompt):
-    llm = ChatOllama(model=LLM_MODEL, temperature=0.2)
+    llm = ChatOllama(
+        model=LLM_MODEL,
+        temperature=0.2,
+        base_url=OLLAMA_BASE_URL
+    )
     prompt = ChatPromptTemplate.from_messages([
         ("system", system_prompt),
         ("human", USER_PROMPT_TEMPLATE)
@@ -291,7 +273,10 @@ async def main():
         border_style="cyan"
     ))
 
-    embeddings = OllamaEmbeddings(model=EMBEDDING_MODEL)
+    embeddings = OllamaEmbeddings(
+        model=EMBEDDING_MODEL,
+        base_url=OLLAMA_BASE_URL
+    )
     vectorstore = Chroma(
         collection_name=COLLECTION_NAME,
         persist_directory=CHROMA_PATH,
@@ -301,14 +286,13 @@ async def main():
     processor = ChunkProcessor(vectorstore)
     cache = load_hash_cache()
 
-    console.print("\n[yellow]Checking documents...[/yellow]")
+    console.print("Checking documents...", style="yellow")
     files = [
         os.path.join(root, file)
         for root, _, files in os.walk(MD_DIRECTORY)
         for file in files if file.endswith(".md")
    ]
 
-    # Initial Indexing
     semaphore = asyncio.Semaphore(MAX_PARALLEL_FILES)
     async def sem_task(fp):
         async with semaphore:
@@ -322,77 +306,72 @@ async def main():
     observer = start_watcher(processor, cache)
     memory = ConversationMemory()
 
-    console.print("[bold green]💬 Ready! Type 'exit' to quit.[/bold green]\n")
+    console.print("💬 Ready! Type 'exit' to quit.", style="bold green")
 
     try:
-        with patch_stdout():
-            while True:
-                query = await session.prompt_async("> ", style=style)
-                query = query.strip()
-                if query.lower() in {"exit", "quit", "q"}:
-                    print("Goodbye!")
-                    break
-                if not query: continue
+        while True:
+            query = await session.prompt_async("> ", style=style)
+            query = query.strip()
+            if query.lower() in {"exit", "quit", "q"}:
+                console.print("Goodbye!", style="yellow")
+                break
+            if not query: continue
 
-                mode = classify_intent(query)
-                history_str = memory.get_history()
+            mode = classify_intent(query)
+            history_str = memory.get_history()
 
-                if mode == "SEARCH":
-                    console.print("[bold blue]🔍 SEARCH MODE (Top-K)[/bold blue]")
-
-                    # Standard RAG
-                    retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
-                    docs = await asyncio.to_thread(retriever.invoke, query)
-                    context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
-
-                    chain = get_chain(SYSTEM_PROMPT_SEARCH)
-
-                else: # ANALYSIS MODE
-                    console.print("[bold magenta]📊 ANALYSIS MODE (Full Context)[/bold magenta]")
-
-                    # Fetch ALL documents (limited by size)
-                    # Chroma .get() returns dict with keys: ids, embeddings, documents, metadatas
-                    db_data = await asyncio.to_thread(vectorstore.get)
-                    all_texts = db_data['documents']
-                    all_metas = db_data['metadatas']
-
-                    if not all_texts:
-                        console.print("[red]No documents found to analyze![/red]")
-                        continue
-
-                    # Concatenate content for analysis
-                    full_context = ""
-                    char_count = 0
-
-                    # Sort arbitrarily or by source to group files
-                    paired = sorted(zip(all_texts, all_metas), key=lambda x: x[1]['source'])
-
-                    for text, meta in paired:
-                        entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
-                        if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
-                            full_context += "\n[...Truncated due to context limit...]"
-                            console.print("[yellow]⚠ Context limit reached, truncating analysis data.[/yellow]")
-                            break
-                        full_context += entry
-                        char_count += len(entry)
-
-                    context_str = full_context
-                    chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
-
-                response = ""
-                console.print(f"[dim]Context size: {len(context_str)} chars[/dim]")
+            if mode == "SEARCH":
+                console.print("🔍 SEARCH MODE (Top-K)", style="bold blue")
 
-                async for chunk in chain.astream({
-                    "context": context_str,
-                    "question": query,
-                    "history": history_str
-                }):
-                    print(chunk, end="")
-                    response += chunk
-                console.print("\n")
+                retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
+                docs = await asyncio.to_thread(retriever.invoke, query)
+                context_str = "\n\n".join(f"[{Path(d.metadata['source']).name}]\n{d.page_content}" for d in docs)
+
+                chain = get_chain(SYSTEM_PROMPT_SEARCH)
 
-                memory.add("user", query)
-                memory.add("assistant", response)
+            else: # ANALYSIS MODE
+                console.print("📊 ANALYSIS MODE (Full Context)", style="bold magenta")
+
+                db_data = await asyncio.to_thread(vectorstore.get)
+                all_texts = db_data['documents']
+                all_metas = db_data['metadatas']
+
+                if not all_texts:
+                    console.print("No documents found to analyze!", style="red")
+                    continue
+
+                full_context = ""
+                char_count = 0
+
+                paired = sorted(zip(all_texts, all_metas), key=lambda x: x[1]['source'])
+
+                for text, meta in paired:
+                    entry = f"\n---\nSource: {Path(meta['source']).name}\n{text}\n"
+                    if char_count + len(entry) > MAX_ANALYSIS_CONTEXT_CHARS:
+                        full_context += "\n[...Truncated due to context limit...]"
+                        console.print("⚠ Context limit reached, truncating analysis data.", style="yellow")
+                        break
+                    full_context += entry
+                    char_count += len(entry)
+
+                context_str = full_context
+                chain = get_chain(SYSTEM_PROMPT_ANALYSIS)
+
+            response = ""
+            console.print(f"Context size: {len(context_str)} chars", style="dim")
+            console.print("Assistant:", style="blue", end=" ")
+
+            async for chunk in chain.astream({
+                "context": context_str,
+                "question": query,
+                "history": history_str
+            }):
+                print(chunk, end="")
+                response += chunk
+            console.print("\n")
+
+            memory.add("user", query)
+            memory.add("assistant", response)
 
     finally:
         observer.stop()
@@ -406,5 +385,5 @@ if __name__ == "__main__":
         loop = asyncio.get_event_loop()
         loop.run_until_complete(main())
     except KeyboardInterrupt:
-        console.print("\n[yellow]Goodbye![/yellow]")
+        console.print("Goodbye!", style="yellow")
         sys.exit(0)
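Not part of the patch: a quick way to sanity-check the new OLLAMA_BASE_URL wiring before running main.py. This is a minimal sketch under stated assumptions — it uses only the standard library plus python-dotenv (already a project dependency), the helper filename is made up, and it assumes an Ollama server is reachable at the configured URL. It queries Ollama's /api/tags endpoint, which lists the pulled models, so you can confirm that the EMBEDDING_MODEL and LLM_MODEL named in .env are actually available on that host.

# check_ollama.py — hypothetical helper, not included in this diff.
# Verifies that the Ollama server configured via OLLAMA_BASE_URL is reachable
# and prints the models it has pulled.
import json
import os
from urllib.request import urlopen

from dotenv import load_dotenv

load_dotenv()
base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")

# GET /api/tags returns {"models": [{"name": ...}, ...]} on a running Ollama server.
with urlopen(f"{base_url}/api/tags", timeout=5) as resp:
    models = [m["name"] for m in json.load(resp).get("models", [])]

print(f"Ollama reachable at {base_url}")
print("Pulled models:", ", ".join(models) or "(none)")

If this prints the models referenced in .env.example (mxbai-embed-large:latest and qwen2.5:7b-instruct-q8_0), then OllamaEmbeddings and ChatOllama should both resolve the same base_url at startup.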