"""
Foundation Models - RAG Pipeline Implementation

Implements a simple Retrieval-Augmented Generation (RAG) pipeline.
Demonstrates document retrieval using TF-IDF and simple embeddings.
Shows prompt composition with retrieved context.

No external LLM API calls - focuses on retrieval and prompt engineering.
"""
 10
 11import re
 12import math
 13from collections import Counter, defaultdict
 14import numpy as np
 15
 16
 17class TFIDFRetriever:
 18    """Simple TF-IDF based document retriever."""
 19
 20    def __init__(self):
 21        self.documents = []
 22        self.vocab = {}
 23        self.idf = {}
 24        self.doc_vectors = []
 25
 26    def tokenize(self, text):
 27        """Simple tokenization: lowercase and split."""
 28        text = text.lower()
 29        text = re.sub(r'[^\w\s]', ' ', text)
 30        return text.split()
 31
 32    def compute_tf(self, tokens):
 33        """Compute term frequency for a document."""
 34        tf = Counter(tokens)
 35        total = len(tokens)
 36
 37        # Normalize
 38        for term in tf:
 39            tf[term] = tf[term] / total
 40
 41        return tf
 42
 43    def build_index(self, documents):
 44        """
 45        Build TF-IDF index from documents.
 46
 47        Args:
 48            documents: List of document strings
 49        """
 50        self.documents = documents
 51        n_docs = len(documents)
 52
 53        # Tokenize all documents
 54        tokenized_docs = [self.tokenize(doc) for doc in documents]
 55
 56        # Build vocabulary
 57        all_terms = set()
 58        for tokens in tokenized_docs:
 59            all_terms.update(tokens)
 60
 61        self.vocab = {term: idx for idx, term in enumerate(sorted(all_terms))}
 62
 63        # Compute IDF
 64        df = Counter()
 65        for tokens in tokenized_docs:
 66            unique_terms = set(tokens)
 67            df.update(unique_terms)
 68
 69        for term in self.vocab:
 70            # IDF = log(N / df(t))
 71            self.idf[term] = math.log(n_docs / (df[term] + 1))
 72
 73        # Compute TF-IDF vectors for all documents
 74        self.doc_vectors = []
 75        for tokens in tokenized_docs:
 76            tf = self.compute_tf(tokens)
 77            vector = np.zeros(len(self.vocab))
 78
 79            for term, freq in tf.items():
 80                if term in self.vocab:
 81                    idx = self.vocab[term]
 82                    vector[idx] = freq * self.idf[term]
 83
 84            self.doc_vectors.append(vector)
 85
 86        print(f"Indexed {n_docs} documents with vocabulary size {len(self.vocab)}")
 87
 88    def retrieve(self, query, top_k=3):
 89        """
 90        Retrieve top-k most relevant documents for query.
 91
 92        Args:
 93            query: Query string
 94            top_k: Number of documents to retrieve
 95
 96        Returns:
 97            List of (doc_idx, score, document) tuples
 98        """
 99        # Compute query vector
100        tokens = self.tokenize(query)
101        tf = self.compute_tf(tokens)
102
103        query_vector = np.zeros(len(self.vocab))
104        for term, freq in tf.items():
105            if term in self.vocab:
106                idx = self.vocab[term]
107                query_vector[idx] = freq * self.idf[term]
108
109        # Compute cosine similarity with all documents
110        scores = []
111        for doc_idx, doc_vector in enumerate(self.doc_vectors):
112            # Cosine similarity
113            dot_product = np.dot(query_vector, doc_vector)
114            query_norm = np.linalg.norm(query_vector)
115            doc_norm = np.linalg.norm(doc_vector)
116
117            if query_norm > 0 and doc_norm > 0:
118                similarity = dot_product / (query_norm * doc_norm)
119            else:
120                similarity = 0.0
121
122            scores.append((doc_idx, similarity, self.documents[doc_idx]))
123
124        # Sort by score and return top-k
125        scores.sort(key=lambda x: x[1], reverse=True)
126        return scores[:top_k]
127
128
class SimpleEmbeddingRetriever:
    """Simple embedding-based retriever using random projections.

    Word embeddings are random unit vectors; a document (or query)
    embedding is the mean of its known word embeddings. Ranking uses
    cosine similarity.
    """

    def __init__(self, embedding_dim=128):
        self.embedding_dim = embedding_dim
        self.documents = []        # raw document strings, in index order
        self.vocab = {}            # token -> index (kept for introspection)
        self.word_embeddings = {}  # token -> unit-norm np.ndarray
        self.doc_embeddings = []   # one mean-embedding np.ndarray per document

    def tokenize(self, text):
        """Simple tokenization: lowercase, strip punctuation, split."""
        text = text.lower()
        text = re.sub(r'[^\w\s]', ' ', text)
        return text.split()

    def _embed(self, tokens):
        """Mean embedding of the known tokens; zero vector if none are known."""
        vectors = [self.word_embeddings[t] for t in tokens if t in self.word_embeddings]
        if vectors:
            return np.mean(vectors, axis=0)
        return np.zeros(self.embedding_dim)

    def build_index(self, documents):
        """Build simple embedding index."""
        self.documents = documents

        # Build vocabulary over all documents (sorted for determinism).
        all_tokens = []
        for doc in documents:
            all_tokens.extend(self.tokenize(doc))
        self.vocab = {token: idx for idx, token in enumerate(sorted(set(all_tokens)))}

        # Create random unit-norm word embeddings (in practice, use
        # pretrained). A local Generator keeps runs reproducible without
        # mutating NumPy's global RNG state (np.random.seed side effect).
        rng = np.random.default_rng(42)
        for token in self.vocab:
            vec = rng.standard_normal(self.embedding_dim)
            self.word_embeddings[token] = vec / np.linalg.norm(vec)

        # Document embeddings: mean of each document's word embeddings.
        self.doc_embeddings = [self._embed(self.tokenize(doc)) for doc in documents]

        print(f"Indexed {len(documents)} documents with {self.embedding_dim}-dim embeddings")

    def retrieve(self, query, top_k=3):
        """Retrieve top-k documents by cosine similarity of mean embeddings.

        Returns:
            List of (doc_idx, score, document) tuples, best first.
        """
        query_emb = self._embed(self.tokenize(query))
        query_norm = np.linalg.norm(query_emb)

        scores = []
        for doc_idx, doc_emb in enumerate(self.doc_embeddings):
            doc_norm = np.linalg.norm(doc_emb)
            # True cosine similarity. The original code took a raw dot
            # product of unnormalized mean vectors, which biased scores
            # toward documents whose mean embedding had a larger norm.
            if query_norm > 0 and doc_norm > 0:
                similarity = float(np.dot(query_emb, doc_emb) / (query_norm * doc_norm))
            else:
                similarity = 0.0
            scores.append((doc_idx, similarity, self.documents[doc_idx]))

        scores.sort(key=lambda x: x[1], reverse=True)
        return scores[:top_k]
198
199
class RAGPipeline:
    """Complete RAG pipeline with retrieval and prompt composition."""

    def __init__(self, retriever):
        # Any object exposing retrieve(query, top_k) -> [(idx, score, doc), ...].
        self.retriever = retriever

    def generate_prompt(self, query, context_docs, system_prompt=None):
        """
        Compose RAG prompt with retrieved context.

        Args:
            query: User query
            context_docs: Retrieved documents as (doc_idx, score, doc) tuples
            system_prompt: Optional system instruction

        Returns:
            Formatted prompt string
        """
        sections = []

        # Optional system instruction first.
        if system_prompt:
            sections.append(f"System: {system_prompt}\n")

        # Numbered context passages.
        sections.append("Context:\n")
        sections.extend(
            f"[{rank}] {doc}\n"
            for rank, (_, _, doc) in enumerate(context_docs, start=1)
        )

        # The question itself, then the answer cue.
        sections.append(f"\nQuestion: {query}\n")
        sections.append("\nAnswer based on the context above:")

        return ''.join(sections)

    def query(self, query, top_k=3, system_prompt=None):
        """
        Full RAG query pipeline.

        Args:
            query: User question
            top_k: Number of documents to retrieve
            system_prompt: Optional system instruction

        Returns:
            Dictionary with retrieved docs and formatted prompt
        """
        retrieved = self.retriever.retrieve(query, top_k=top_k)
        return {
            'query': query,
            'retrieved_docs': retrieved,
            'prompt': self.generate_prompt(query, retrieved, system_prompt),
        }
258
259
260# ============================================================
261# Demonstrations
262# ============================================================
263
def get_sample_documents():
    """Return the sample knowledge base: ten short AI/programming facts."""
    # Each entry is one self-contained factual snippet; retrieval demos
    # treat each string as a separate document.
    return [
        "Python is a high-level programming language known for its simplicity and readability. It was created by Guido van Rossum in 1991.",
        "Machine learning is a subset of artificial intelligence that enables systems to learn from data without explicit programming.",
        "Neural networks are computing systems inspired by biological neural networks. They consist of interconnected nodes organized in layers.",
        "Deep learning uses neural networks with many layers to learn hierarchical representations of data.",
        "Natural language processing (NLP) is a field of AI focused on enabling computers to understand and generate human language.",
        "Transformers are a type of neural network architecture introduced in 2017 that use self-attention mechanisms.",
        "Large language models like GPT are trained on vast amounts of text data to generate human-like text.",
        "Transfer learning involves taking a pretrained model and fine-tuning it for a specific task.",
        "Computer vision is a field of AI that enables computers to understand and interpret visual information from images and videos.",
        "Reinforcement learning is a type of machine learning where agents learn to make decisions by interacting with an environment.",
    ]
278
279
def demo_tfidf_retrieval():
    """Demonstrate TF-IDF based retrieval."""
    banner = "=" * 60
    print(banner)
    print("DEMO 1: TF-IDF Retrieval")
    print(banner)

    retriever = TFIDFRetriever()
    retriever.build_index(get_sample_documents())

    # Run a few representative questions through the retriever.
    for query in ("What is Python?",
                  "How do neural networks work?",
                  "What are transformers?"):
        print(f"\nQuery: {query}")
        print("-" * 60)
        for rank, (doc_idx, score, doc) in enumerate(retriever.retrieve(query, top_k=3), 1):
            print(f"{rank}. [Score: {score:.4f}] {doc[:80]}...")
306
307
def demo_embedding_retrieval():
    """Demonstrate embedding-based retrieval."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 2: Embedding-Based Retrieval")
    print(banner)

    retriever = SimpleEmbeddingRetriever(embedding_dim=64)
    retriever.build_index(get_sample_documents())

    # Short, keyword-style queries to exercise the embedding matcher.
    sample_queries = (
        "programming languages",
        "AI and machine learning",
        "language models",
    )
    for query in sample_queries:
        print(f"\nQuery: {query}")
        print("-" * 60)
        hits = retriever.retrieve(query, top_k=3)
        for rank, (doc_idx, score, doc) in enumerate(hits, 1):
            print(f"{rank}. [Score: {score:.4f}] {doc[:80]}...")
333
334
def demo_rag_pipeline():
    """Demonstrate complete RAG pipeline."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 3: Complete RAG Pipeline")
    print(banner)

    # Wire a TF-IDF retriever into the pipeline.
    retriever = TFIDFRetriever()
    retriever.build_index(get_sample_documents())
    rag = RAGPipeline(retriever)

    question = "What is the relationship between deep learning and neural networks?"
    instruction = ("You are a helpful AI assistant. "
                   "Answer questions based only on the provided context.")

    result = rag.query(question, top_k=3, system_prompt=instruction)

    print(f"\nQuery: {result['query']}")
    print("\n" + banner)
    print("Retrieved Documents:")
    print(banner)

    for rank, (doc_idx, score, doc) in enumerate(result['retrieved_docs'], 1):
        print(f"\n{rank}. [Score: {score:.4f}]")
        print(f"   {doc}")

    print("\n" + banner)
    print("Generated Prompt:")
    print(banner)
    print(result['prompt'])
370
371
def demo_retrieval_comparison():
    """Compare TF-IDF vs embedding retrieval."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 4: Retrieval Method Comparison")
    print(banner)

    documents = get_sample_documents()

    # Index the same corpus with both retrievers.
    tfidf = TFIDFRetriever()
    tfidf.build_index(documents)
    embedding = SimpleEmbeddingRetriever(embedding_dim=128)
    embedding.build_index(documents)

    query = "What is artificial intelligence?"
    print(f"\nQuery: {query}\n")

    # Show each method's top hits side by side.
    comparisons = (
        ("TF-IDF Retrieval:", tfidf.retrieve(query, top_k=3)),
        ("\nEmbedding Retrieval:", embedding.retrieve(query, top_k=3)),
    )
    for heading, hits in comparisons:
        print(heading)
        print("-" * 60)
        for rank, (doc_idx, score, doc) in enumerate(hits, 1):
            print(f"{rank}. [Score: {score:.4f}] Doc {doc_idx}")
404
405
def demo_prompt_engineering():
    """Demonstrate different prompt strategies in RAG."""
    banner = "=" * 60
    print("\n" + banner)
    print("DEMO 5: RAG Prompt Engineering")
    print(banner)

    retriever = TFIDFRetriever()
    retriever.build_index(get_sample_documents())
    rag = RAGPipeline(retriever)

    query = "How does transfer learning work?"

    # Same question, three system-prompt styles.
    styles = (
        ("Basic", "Answer the question based on the context."),
        ("Detailed", "Provide a detailed answer using only information from the context. If the context doesn't contain the answer, say so."),
        ("Concise", "Give a brief, one-sentence answer based on the context."),
    )
    for style, system_prompt in styles:
        print(f"\n{style} Style:")
        print("-" * 60)
        result = rag.query(query, top_k=2, system_prompt=system_prompt)
        print(result['prompt'][:300] + "...\n")
431
432
if __name__ == "__main__":
    banner = "=" * 60
    print("\n" + banner)
    print("Foundation Models: RAG Pipeline")
    print(banner)

    # Run every demo in order.
    for demo in (demo_tfidf_retrieval,
                 demo_embedding_retrieval,
                 demo_rag_pipeline,
                 demo_retrieval_comparison,
                 demo_prompt_engineering):
        demo()

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    for takeaway in ("1. RAG combines retrieval with generation for grounded answers",
                     "2. TF-IDF: sparse retrieval based on term importance",
                     "3. Embeddings: dense retrieval based on semantic similarity",
                     "4. Prompt engineering: format context and query effectively",
                     "5. Retrieval quality directly impacts generation quality"):
        print(takeaway)
    print(banner)