Retrieval Augmented Generation with DeepSeek R1: SkyPilot documentation

Source: llm/rag

RAG with DeepSeek R1

For the full blog post, please find it here.

As legal document collections grow, traditional keyword search becomes insufficient for finding relevant information. Retrieval-Augmented Generation (RAG) combines the power of vector search with large language models to enable semantic search and intelligent answers.

In particular:

SkyPilot streamlines the development and deployment of RAG systems on any cloud or Kubernetes cluster by managing infrastructure and enabling efficient, cost-effective use of compute resources.

In this example, we use legal documents from the Pile of Law dataset to demonstrate RAG capabilities. The system processes a collection of legal texts, including case law, statutes, and legal discussions, to enable semantic search and intelligent question answering. This approach can help legal professionals quickly find relevant precedents, analyze complex legal scenarios, and extract insights from large document collections.

We use Alibaba-NLP/gte-Qwen2-7B-instruct for generating document embeddings and a distilled DeepSeek R1 model (deepseek-ai/DeepSeek-R1-Distill-Llama-8B) for generating final answers.

Why SkyPilot: SkyPilot streamlines the process of running such large-scale jobs in the cloud. It abstracts away much of the complexity of managing infrastructure and helps you run compute-intensive tasks efficiently and cost-effectively through managed jobs.

Step 0: Set Up The Environment

Install the following Prerequisites:

Set up bucket names for storing embeddings and vector database:

export EMBEDDINGS_BUCKET_NAME=sky-rag-embeddings
export VECTORDB_BUCKET_NAME=sky-rag-vectordb

Note that these bucket names need to be unique to the entire SkyPilot community.

Step 1: Compute Embeddings

Convert legal documents into vector representations using Alibaba-NLP/gte-Qwen2-7B-instruct. These embeddings enable semantic search across the document collection.

Launch the embedding computation:

python3 batch_compute_embeddings.py --embedding_bucket_name $EMBEDDINGS_BUCKET_NAME

Here is how the Python script launches vLLM with Alibaba-NLP/gte-Qwen2-7B-instruct for embedding generation. Each worker processes the documents from START_IDX to END_IDX.

SkyPilot YAML for embedding generation

name: compute-legal-embeddings

resources:
  accelerators: {L4:1, A100:1}

envs:
  START_IDX: 0  # Will be overridden by batch_compute_embeddings.py
  END_IDX: 10000  # Will be overridden by batch_compute_embeddings.py
  MODEL_NAME: "Alibaba-NLP/gte-Qwen2-7B-instruct"
  EMBEDDINGS_BUCKET_NAME: sky-rag-embeddings  # Bucket name for storing embeddings

file_mounts:
  /output:
    name: ${EMBEDDINGS_BUCKET_NAME}
    mode: MOUNT

setup: |
  pip install torch==2.5.1 vllm==0.6.6.post1
  ...

run: |
  python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --model $MODEL_NAME \
    --max-model-len 3072 \
    --task embed &

  python scripts/compute_embeddings.py \
    --start-idx $START_IDX \
    --end-idx $END_IDX \
    --chunk-size 2048 \
    --chunk-overlap 512 \
    --vllm-endpoint http://localhost:8000
This automatically launches 10 SkyPilot managed jobs on L4 GPUs to process documents from the Pile of Law dataset and compute embeddings in batches:

Processing documents: 100%|██████████| 1000/1000 [00:45<00:00, 22.05it/s]
Saving embeddings to: embeddings_0_1000.parquet
...

We leverage SkyPilot’s managed jobs feature to enable parallel processing across multiple regions and cloud providers. SkyPilot handles job state management and automatic recovery from failures when using spot instances. Managed jobs are cost-efficient and streamline the processing of the partitioned dataset. You can check all the jobs by running sky jobs dashboard.
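To make the partitioning concrete, here is a minimal sketch of how a global index range can be split evenly across jobs. It mirrors the logic of calculate_job_range in batch_compute_embeddings.py (listed under Included files); the job count of 10 and the small 10-document example are illustrative assumptions, not values the script enforces.

```python
def calculate_job_range(start_idx, end_idx, job_rank, total_jobs):
    """Return the half-open index range [job_start, job_end) for one job."""
    total_range = end_idx - start_idx
    chunk_size, remainder = divmod(total_range, total_jobs)
    # The first `remainder` jobs each take one extra document.
    job_start = start_idx + job_rank * chunk_size + min(job_rank, remainder)
    job_end = job_start + chunk_size + (1 if job_rank < remainder else 0)
    return job_start, job_end


# 109740 is the default --end-idx in batch_compute_embeddings.py;
# 10 jobs is an assumed value for --num-jobs.
ranges = [calculate_job_range(0, 109740, rank, 10) for rank in range(10)]
for job_start, job_end in ranges:
    # Same naming pattern the script uses for its managed jobs.
    print(f'rag-compute-{job_start}-{job_end}')

# An uneven split: 10 documents over 3 jobs.
uneven = [calculate_job_range(0, 10, rank, 3) for rank in range(3)]
print(uneven)  # [(0, 4), (4, 7), (7, 10)]
```

Because the ranges are half-open and contiguous, every document index is assigned to exactly one job, and each job writes its own parquet files without coordinating with the others.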

job dashboard

All generated embeddings are stored efficiently in parquet format within a cloud storage bucket.
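Before embedding, each job splits documents into overlapping chunks (the YAML passes --chunk-size 2048 and --chunk-overlap 512). The following is a rough sketch of sentence-aware chunking under simplifying assumptions: it uses a naive regular-expression sentence splitter instead of NLTK, and the function name chunk_sentences and its overlap_sentences parameter are hypothetical. The included scripts/compute_embeddings.py behaves similarly in that it carries the last two sentences into the next chunk whenever the overlap is greater than zero.

```python
import re


def chunk_sentences(text, chunk_size=2048, overlap_sentences=2):
    """Greedily pack sentences into chunks of roughly chunk_size characters.

    The last `overlap_sentences` sentences of a chunk are repeated at the
    start of the next one so that context is not cut at chunk boundaries.
    """
    # Naive splitter (assumption): break after '.', '!' or '?' plus whitespace.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]

    chunks, current, current_len = [], [], 0
    for sentence in sentences:
        # Close the current chunk if this sentence would push it over the limit.
        if current and current_len + len(sentence) > chunk_size:
            chunks.append(' '.join(current))
            current = current[-overlap_sentences:] if overlap_sentences else []
            current_len = sum(len(s) + 1 for s in current)
        current.append(sentence)
        current_len += len(sentence) + 1  # +1 for the joining space

    if current:
        chunks.append(' '.join(current))
    return chunks


demo = chunk_sentences('A one. B two. C three. D four.',
                       chunk_size=15,
                       overlap_sentences=1)
print(demo)  # ['A one. B two.', 'B two. C three.', 'C three. D four.']
```

As in the included script, a chunk can exceed chunk_size when a single sentence (plus the carried-over overlap) is longer than the limit; the limit only decides when a chunk is closed.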

Step 2: Build RAG with Vector Database

After computing embeddings, construct a ChromaDB vector database for efficient similarity search:

sky launch build_rag.yaml --env EMBEDDINGS_BUCKET_NAME=$EMBEDDINGS_BUCKET_NAME --env VECTORDB_BUCKET_NAME=$VECTORDB_BUCKET_NAME

The process builds the database in batches:

Loading embeddings from: embeddings_0_1000.parquet
Adding vectors to ChromaDB: 100%|██████████| 1000/1000 [00:12<00:00, 81.97it/s]
...
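The batching can be pictured roughly as follows. This is a sketch only: iter_batches and FakeCollection are hypothetical names, and FakeCollection is a stand-in that merely records batch sizes so the example runs without ChromaDB. The keyword arguments passed to add (ids, embeddings, documents, metadatas) are the same ones scripts/build_rag.py passes to the ChromaDB collection.

```python
def iter_batches(records, batch_size=1000):
    """Yield (ids, embeddings, documents, metadatas) for batch_size rows at a time."""
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]
        yield (
            [str(row['id']) for row in batch],
            [row['embedding'] for row in batch],
            [row['content'] for row in batch],
            # Metadata keeps only the descriptive fields, not the content.
            [{key: row[key] for key in ('name', 'split', 'source')}
             for row in batch],
        )


class FakeCollection:
    """Hypothetical stand-in for a ChromaDB collection; it only counts rows."""

    def __init__(self):
        self.batch_sizes = []

    def add(self, ids, embeddings, documents, metadatas):
        self.batch_sizes.append(len(ids))

    def count(self):
        return sum(self.batch_sizes)


# Made-up rows shaped like the parquet records the embedding jobs write.
records = [{
    'id': i,
    'embedding': [0.0, 1.0],
    'content': f'document {i}',
    'name': f'doc-{i}',
    'split': 'train',
    'source': 'example',
} for i in range(2500)]

collection = FakeCollection()
for ids, embeddings, documents, metadatas in iter_batches(records):
    collection.add(ids=ids,
                   embeddings=embeddings,
                   documents=documents,
                   metadatas=metadatas)

print(collection.batch_sizes, collection.count())  # [1000, 1000, 500] 2500
```

Adding in fixed-size batches keeps memory use bounded, which is why the --batch-size value needs to fit in memory on the build machine.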

Step 3: Serve the RAG

Deploy the RAG service to handle queries and generate answers:

sky launch -c legal-rag serve_rag.yaml --env VECTORDB_BUCKET_NAME=$VECTORDB_BUCKET_NAME

Or use Sky Serve for managed deployment:

sky serve up -n legal-rag serve_rag.yaml --env VECTORDB_BUCKET_NAME=$VECTORDB_BUCKET_NAME

To query the system, get the endpoint:

sky serve status legal-rag --endpoint

You can visit the website and input your query there! A few queries to try out:

I want to break my lease. My landlord doesn’t allow me to do that.
My employer has not provided the final paycheck after termination.
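The service can also be queried programmatically. Below is a minimal sketch that builds a POST request for the /rag route defined in scripts/serve_rag.py, whose request body takes query, n_results, and temperature. The helper names build_rag_request and ask, the example address 203.0.113.10:30001, and the assumption that the endpoint is served over plain HTTP are illustrative and should be adapted to the endpoint printed by sky serve status.

```python
import json
import urllib.request


def build_rag_request(endpoint, query, n_results=3, temperature=0.7):
    """Build (but do not send) a POST request for the /rag route."""
    # Assumption: an endpoint printed without a scheme is reachable over HTTP.
    if not endpoint.startswith(('http://', 'https://')):
        endpoint = f'http://{endpoint}'
    payload = {
        'query': query,
        'n_results': n_results,
        'temperature': temperature,
    }
    return urllib.request.Request(
        endpoint.rstrip('/') + '/rag',
        data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'},
        method='POST')


def ask(endpoint, query):
    """Send the query and return the decoded JSON response."""
    request = build_rag_request(endpoint, query)
    with urllib.request.urlopen(request, timeout=180) as response:
        return json.loads(response.read())


# Placeholder address; replace it with the real endpoint.
request = build_rag_request(
    '203.0.113.10:30001',
    'My employer has not provided the final paycheck after termination.')
print(request.get_method(), request.full_url)
```

Calling ask(endpoint, query) against a running service should return a JSON object with the answer, sources, and thinking_process fields declared by RAGResponse in scripts/serve_rag.py.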

Disclaimer

This document provides instructions for building a RAG system with SkyPilot. The system and its outputs should not be considered legal advice. Please consult qualified legal professionals for any legal matters.

Included files

batch_compute_embeddings.py

#!/usr/bin/env python3
"""Use SkyPilot to launch managed jobs that run the embedding calculation for RAG.

This script splits the input dataset among several workers, then uses SkyPilot
to launch a managed job for each worker. compute_embeddings.yaml defines the
managed job.
"""

import argparse

import sky

def calculate_job_range(start_idx: int, end_idx: int, job_rank: int,
                        total_jobs: int) -> tuple[int, int]:
    """Calculate the range of indices this job should process.

    Args:
        start_idx: Global start index
        end_idx: Global end index
        job_rank: Current job's rank (0-based)
        total_jobs: Total number of jobs

    Returns:
        Tuple (job_start_idx, job_end_idx) for the half-open range
        [job_start_idx, job_end_idx)
    """
    total_range = end_idx - start_idx
    chunk_size = total_range // total_jobs
    remainder = total_range % total_jobs

    # Distribute remainder across first few jobs
    job_start = start_idx + (job_rank * chunk_size) + min(job_rank, remainder)
    if job_rank < remainder:
        chunk_size += 1
    job_end = job_start + chunk_size

    return job_start, job_end

def main():
    parser = argparse.ArgumentParser(
        description='Launch batch RAG embedding computation jobs')
    parser.add_argument('--start-idx',
                        type=int,
                        default=0,
                        help='Global start index in dataset')
    parser.add_argument(
        '--end-idx',
        type=int,
        # this is the last index of the reddit post dataset
        default=109740,
        help='Global end index in dataset, not inclusive')
    parser.add_argument('--num-jobs',
                        type=int,
                        default=1,
                        help='Number of jobs to partition the work across')
    parser.add_argument('--embedding_bucket_name',
                        type=str,
                        default='sky-rag-embeddings',
                        help='Name of the bucket to store embeddings')

    args = parser.parse_args()

    # Load the task template
    task = sky.Task.from_yaml('compute_embeddings.yaml')

    # Launch jobs for each partition
    for job_rank in range(args.num_jobs):
        # Calculate index range for this job
        job_start, job_end = calculate_job_range(args.start_idx, args.end_idx,
                                                 job_rank, args.num_jobs)

        # Update environment variables for this job
        task_copy = task.update_envs({
            'START_IDX': job_start,
            'END_IDX': job_end,
            'EMBEDDINGS_BUCKET_NAME': args.embedding_bucket_name,
        })

        sky.jobs.launch(task_copy, name=f'rag-compute-{job_start}-{job_end}')

if __name__ == '__main__':
    main()

build_rag.yaml

name: build-legal-rag

workdir: .

resources:
  memory: 32+  # Need more memory for merging embeddings
  infra: aws

envs:
  EMBEDDINGS_BUCKET_NAME: sky-rag-embeddings
  VECTORDB_BUCKET_NAME: sky-rag-vectordb

file_mounts:
  /embeddings:
    name: ${EMBEDDINGS_BUCKET_NAME}  # this needs to be the same as the output in compute_embeddings.yaml
    mode: MOUNT

  /vectordb:
    name: ${VECTORDB_BUCKET_NAME}
    mode: MOUNT

setup: |
  pip install chromadb pandas tqdm pyarrow

run: |
  python scripts/build_rag.py \
    --collection-name legal_docs \
    --persist-dir /vectordb/chroma \
    --embeddings-dir /embeddings \
    --batch-size 1000

compute_embeddings.yaml

name: compute-law-embeddings

workdir: .

resources:
  accelerators:
    L4: 1
  memory: 32+
  any_of:
    - use_spot: true
    - use_spot: false

envs:
  START_IDX: 0  # Will be overridden by batch_compute_embeddings.py
  END_IDX: 10000  # Will be overridden by batch_compute_embeddings.py
  MODEL_NAME: "Alibaba-NLP/gte-Qwen2-7B-instruct"
  EMBEDDINGS_BUCKET_NAME: sky-rag-embeddings  # Bucket name for storing embeddings

file_mounts:
  /output:
    name: ${EMBEDDINGS_BUCKET_NAME}
    mode: MOUNT

setup: |
  # Install dependencies for vLLM
  pip install transformers==4.48.1 vllm==0.6.6.post1

  # Install dependencies for embedding computation
  pip install numpy pandas requests tqdm datasets
  pip install nltk hf_transfer

run: |
  # Initialize and download the model
  HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/model $MODEL_NAME

  # Start vLLM service in background
  python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --model /tmp/model \
    --max-model-len 3072 \
    --task embed &

  # Wait for vLLM to be ready by checking the health endpoint
  echo "Waiting for vLLM service to be ready..."
  while ! curl -s http://localhost:8000/health > /dev/null; do
    sleep 5
    echo "Still waiting for vLLM service..."
  done
  echo "vLLM service is ready!"

  # Process the assigned range of documents
  echo "Processing documents from $START_IDX to $END_IDX"

  python scripts/compute_embeddings.py \
    --output-path "/output/embeddings_${START_IDX}_${END_IDX}.parquet" \
    --start-idx $START_IDX \
    --end-idx $END_IDX \
    --chunk-size 2048 \
    --chunk-overlap 512 \
    --vllm-endpoint http://localhost:8000 \
    --batch-size 32

  # Clean up vLLM service
  pkill -f "python -m vllm.entrypoints.openai.api_server"
  echo "vLLM service has been stopped"

scripts/build_rag.py

"""Build the vector database from the mounted embeddings bucket and save it
to another mounted bucket.
"""

import argparse
import base64
from concurrent.futures import as_completed
from concurrent.futures import ProcessPoolExecutor
import glob
import logging
import multiprocessing
import os
import pickle
import shutil
import tempfile

import chromadb
import numpy as np
import pandas as pd
from tqdm import tqdm

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def list_local_parquet_files(mount_path: str, prefix: str) -> list:
    """List all parquet files in the mounted directory."""
    search_path = os.path.join(mount_path, prefix, '**/*.parquet')
    parquet_files = glob.glob(search_path, recursive=True)
    return parquet_files

def process_parquet_file(args):
    """Process a single parquet file and return the processed data."""
    parquet_file, batch_size = args
    try:
        results = []
        df = pd.read_parquet(parquet_file)

        # Process in batches
        for i in range(0, len(df), batch_size):
            batch_df = df.iloc[i:i + batch_size]
            # Extract data from DataFrame
            ids = [str(idx) for idx in batch_df['id']]
            # Embeddings were stored as pickled numpy arrays
            embeddings = [pickle.loads(emb) for emb in batch_df['embedding']]
            documents = batch_df['content'].tolist(
            )  # Content goes to documents
            # Create metadata from the available fields (excluding content)
            metadatas = [{
                'name': row['name'],
                'split': row['split'],
                'source': row['source'],
            } for _, row in batch_df.iterrows()]
            results.append((ids, embeddings, documents, metadatas))

        return results
    except Exception as e:
        logger.error(f'Error processing file {parquet_file}: {str(e)}')
        return None

def main():
    parser = argparse.ArgumentParser(
        description='Build ChromaDB from mounted parquet files')
    parser.add_argument('--collection-name',
                        type=str,
                        default='rag_embeddings',
                        help='ChromaDB collection name')
    parser.add_argument('--persist-dir',
                        type=str,
                        default='/vectordb/chroma',
                        help='Directory to persist ChromaDB')
    parser.add_argument(
        '--batch-size',
        type=int,
        default=1000,
        help='Batch size for processing, this needs to fit in memory')
    parser.add_argument('--embeddings-dir',
                        type=str,
                        default='/embeddings',
                        help='Path to mounted bucket containing parquet files')
    parser.add_argument(
        '--prefix',
        type=str,
        default='',
        help='Prefix path within mounted bucket to search for parquet files')

    args = parser.parse_args()

    # Create a temporary directory for building the database. The
    # mounted bucket does not support append operation, so build in
    # the tmpdir and then copy it to the final location.
    with tempfile.TemporaryDirectory() as temp_dir:
        logger.info(f'Using temporary directory: {temp_dir}')

        # Initialize ChromaDB in temporary directory
        client = chromadb.PersistentClient(path=temp_dir)

        # Create or get collection for chromadb
        # it attempts to create a collection with the same name
        # if it already exists, it will get the collection
        try:
            collection = client.create_collection(
                name=args.collection_name,
                metadata={'description': 'RAG embeddings from legal documents'})
            logger.info(f'Created new collection: {args.collection_name}')
        except ValueError:
            collection = client.get_collection(name=args.collection_name)
            logger.info(f'Using existing collection: {args.collection_name}')

        # List parquet files from mounted directory
        parquet_files = list_local_parquet_files(args.embeddings_dir,
                                                 args.prefix)
        logger.info(f'Found {len(parquet_files)} parquet files')

        # Process files in parallel
        max_workers = max(1,
                          multiprocessing.cpu_count() - 1)  # Leave one CPU free
        logger.info(f'Processing files using {max_workers} workers')

        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # Submit all files for processing
            future_to_file = {
                executor.submit(process_parquet_file, (file, args.batch_size)):
                file for file in parquet_files
            }

            # Process results as they complete
            for future in tqdm(as_completed(future_to_file),
                               total=len(parquet_files),
                               desc='Processing files'):
                file = future_to_file[future]
                try:
                    results = future.result()
                    if results:
                        for ids, embeddings, documents, metadatas in results:
                            collection.add(ids=list(ids),
                                           embeddings=list(embeddings),
                                           documents=list(documents),
                                           metadatas=list(metadatas))
                except Exception as e:
                    logger.error(f'Error processing file {file}: {str(e)}')
                    continue

        logger.info('Vector database build complete!')
        logger.info(f'Total documents in collection: {collection.count()}')

        # Copy the completed database to the final location
        logger.info(f'Copying database to final location: {args.persist_dir}')
        if os.path.exists(args.persist_dir):
            logger.info('Removing existing database directory')
            shutil.rmtree(args.persist_dir)
        shutil.copytree(temp_dir, args.persist_dir)
        logger.info('Database copy complete!')

if __name__ == '__main__':
    main()

scripts/compute_embeddings.py

"""Script to compute embeddings for the Pile of Law dataset using
Alibaba-NLP/gte-Qwen2-7B-instruct through vLLM.
"""

import argparse
import logging
import os
from pathlib import Path
import pickle
import shutil
import time
from typing import Dict, List, Tuple

from datasets import load_dataset
import nltk
import numpy as np
import pandas as pd
import requests
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize NLTK to chunk documents
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    logger.info('Downloading NLTK punkt tokenizer...')
    nltk.download('punkt')
    nltk.download('punkt_tab')
    logger.info('Download complete')

def load_law_documents(start_idx: int = 0, end_idx: int = 1000) -> List[Dict]:
    """Load documents from Pile of Law dataset.

    Args:
        start_idx: Starting index in dataset
        end_idx: Ending index in dataset (exclusive)

    Returns:
        List of documents
    """
    dataset = load_dataset('pile-of-law/pile-of-law',
                           'all',
                           split='train',
                           streaming=True,
                           trust_remote_code=True)

    documents = []
    for idx, doc in enumerate(
            dataset.skip(start_idx).take(end_idx - start_idx)):
        documents.append({
            'id': f"{idx + start_idx}",
            'name': doc['url'],
            'text': doc['text'],
            'split': 'train',
            'source': 'r_legaladvice',
            'created_timestamp': doc['created_timestamp'],
            'downloaded_timestamp': doc['downloaded_timestamp'],
            'url': doc['url']
        })

        if (idx + 1) % 100 == 0:
            logger.info(f'Loaded {idx + 1} documents')

    return documents

def chunk_document(document,
                   chunk_size=512,
                   overlap=50,
                   start_chunk_idx=0) -> Tuple[List[Dict], int]:
    """Split document into overlapping chunks using sentence-aware splitting.

    Args:
        document: The document to chunk
        chunk_size: Maximum size of each chunk in characters
        overlap: Number of characters to overlap between chunks
        start_chunk_idx: Starting index for global chunk counting

    Returns:
        List of chunks and the next available chunk index
    """
    text = document['text']
    chunks = []
    chunk_idx = start_chunk_idx

    # Split into sentences first
    sentences = nltk.sent_tokenize(text)

    current_chunk = []
    current_length = 0

    for sentence in sentences:
        sentence_len = len(sentence)

        # If adding this sentence would exceed chunk size, save current chunk
        if current_length + sentence_len > chunk_size and current_chunk:
            chunk_text = ' '.join(current_chunk)
            chunks.append({
                'id': document['id'] + '_' + str(chunk_idx),
                'name': document['name'],
                'content': document['text'],  # Store full document content
                'chunk_text': chunk_text.strip(),  # Store the specific chunk
                'chunk_start': len(' '.join(
                    current_chunk[:-(2 if overlap > 0 else 0)])) if overlap > 0
                               else 0,  # Approximate start position
                'split': document['split'],
                'source': document['source'],
                'document_id': document['id'],
                'document_url': document['url'],
                'document_created_timestamp': document['created_timestamp'],
                'document_downloaded_timestamp':
                    document['downloaded_timestamp']
            })
            chunk_idx += 1

            # Keep last few sentences for overlap
            overlap_text = ' '.join(current_chunk[-2:])  # Keep last 2 sentences
            current_chunk = [overlap_text] if overlap > 0 else []
            current_length = len(overlap_text) if overlap > 0 else 0

        current_chunk.append(sentence)
        current_length += sentence_len + 1  # +1 for space

    # Add the last chunk if it's not empty
    if current_chunk:
        chunk_text = ' '.join(current_chunk)
        chunks.append({
            'id': document['id'] + '_' + str(chunk_idx),
            'name': document['name'],
            'content': document['text'],  # Store full document content
            'chunk_text': chunk_text.strip(),  # Store the specific chunk
            'chunk_start': len(' '.join(
                current_chunk[:-(2 if overlap > 0 else 0)]))
                           if overlap > 0 else 0,  # Approximate start position
            'split': document['split'],
            'source': document['source'],
            'document_id': document['id'],
            'document_url': document['url'],
            'document_created_timestamp': document['created_timestamp'],
            'document_downloaded_timestamp': document['downloaded_timestamp']
        })
        chunk_idx += 1

    return chunks, chunk_idx

def compute_embeddings_batch(chunks: List[Dict],
                             vllm_endpoint: str,
                             output_path: str,
                             batch_size: int = 32,
                             partition_size: int = 1000) -> None:
    """Compute embeddings for document chunks using the embedding model
    served by vLLM and save them in partitions.

    Args:
        chunks: List of document chunks
        vllm_endpoint: Endpoint for vLLM service
        output_path: Path to save embeddings
        batch_size: Number of chunks sent to vLLM per request
        partition_size: Number of embeddings per partition file
    """
    current_partition = []
    partition_counter = 0

    # Process in batches
    for i in tqdm(range(0, len(chunks), batch_size),
                  desc='Computing embeddings'):
        batch = chunks[i:i + batch_size]

        # Create prompt for each chunk - simplified prompt
        prompts = [chunk['content'] for chunk in batch]

        try:
            # Build the request payload
            request_payload = {
                "model": "/tmp/model",
                # because this is loaded from the mounted directory
                "input": prompts
            }

            response = requests.post(f"{vllm_endpoint}/v1/embeddings",
                                     json=request_payload,
                                     timeout=60)

            response.raise_for_status()

            # Extract embeddings - updated response parsing
            result = response.json()

            if 'data' not in result:
                raise ValueError(f"Unexpected response format: {result}")

            embeddings = [item['embedding'] for item in result['data']]

            # Combine embeddings with metadata
            for chunk, embedding in zip(batch, embeddings):
                current_partition.append({
                    'id': chunk['id'],
                    'name': chunk['name'],
                    'content': chunk['content'],
                    'chunk_text': chunk['chunk_text'],
                    'chunk_start': chunk['chunk_start'],
                    'split': chunk['split'],
                    'source': chunk['source'],
                    'embedding': pickle.dumps(np.array(embedding)),
                    # Include document metadata
                    'document_id': chunk['document_id'],
                    'document_url': chunk['document_url'],
                    'document_created_timestamp':
                        chunk['document_created_timestamp'],
                    'document_downloaded_timestamp':
                        chunk['document_downloaded_timestamp']
                })

            # Save partition when it reaches the desired size
            if len(current_partition) >= partition_size:
                save_partition(current_partition, output_path,
                               partition_counter)
                partition_counter += 1
                current_partition = []

        except Exception as e:
            logger.error(f"Error computing embeddings for batch: {str(e)}")
            time.sleep(5)
            continue

    # Save any remaining embeddings in the final partition
    if current_partition:
        save_partition(current_partition, output_path, partition_counter)

def save_partition(results: List[Dict], output_path: str,
                   partition_counter: int) -> None:
    """Save a partition of embeddings to a parquet file with atomic write.

    Args:
        results: List of embeddings
        output_path: Path to save embeddings
        partition_counter: Partition counter
    """
    if not results:
        return

    df = pd.DataFrame(results)
    final_path = f'{output_path}_part_{partition_counter}.parquet'
    temp_path = f'/tmp/embeddings_{partition_counter}.tmp'

    # Write to temporary file first
    df.to_parquet(temp_path, engine='pyarrow', index=False)

    # Copy from temp to final destination
    os.makedirs(os.path.dirname(final_path), exist_ok=True)
    shutil.copy2(temp_path, final_path)
    os.remove(temp_path)  # Clean up temp file

    logger.info(
        f'Saved partition {partition_counter} to {final_path} with {len(df)} rows'
    )

def main():
    parser = argparse.ArgumentParser(
        description='Compute embeddings for Pile of Law dataset')
    parser.add_argument('--output-path',
                        type=str,
                        required=True,
                        help='Path to save embeddings parquet file')
    parser.add_argument('--start-idx',
                        type=int,
                        default=0,
                        help='Starting index in dataset')
    parser.add_argument('--end-idx',
                        type=int,
                        default=1000,
                        help='Ending index in dataset')
    parser.add_argument('--chunk-size',
                        type=int,
                        default=512,
                        help='Size of document chunks')
    parser.add_argument('--chunk-overlap',
                        type=int,
                        default=50,
                        help='Overlap between chunks')
    parser.add_argument('--vllm-endpoint',
                        type=str,
                        required=True,
                        help='Endpoint for vLLM service')
    parser.add_argument('--batch-size',
                        type=int,
                        default=32,
                        help='Batch size for computing embeddings')
    parser.add_argument('--partition-size',
                        type=int,
                        default=1000,
                        help='Number of embeddings per partition file')

    args = parser.parse_args()

    # Create output directory if it doesn't exist
    os.makedirs(os.path.dirname(args.output_path), exist_ok=True)

    # Load documents
    logger.info('Loading documents from Pile of Law dataset...')
    documents = load_law_documents(args.start_idx, args.end_idx)
    logger.info(f'Loaded {len(documents)} documents')

    # Chunk documents with global counter
    logger.info('Chunking documents...')
    chunks = []
    next_chunk_idx = 0  # Initialize global chunk counter
    for doc in documents:
        doc_chunks, next_chunk_idx = chunk_document(doc, args.chunk_size,
                                                    args.chunk_overlap,
                                                    next_chunk_idx)
        chunks.extend(doc_chunks)
    logger.info(f'Created {len(chunks)} chunks')

    # Compute embeddings and save in partitions
    logger.info('Computing embeddings...')
    compute_embeddings_batch(chunks, args.vllm_endpoint, args.output_path,
                             args.batch_size, args.partition_size)
    logger.info('Finished computing and saving embeddings')

if __name__ == '__main__':
    main()

scripts/serve_rag.py

"""Script to serve RAG system combining vector search with DeepSeek R1."""

import argparse
import logging
import os
import pickle
import time
from typing import Any, Dict, List, Optional
import uuid

import chromadb
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi.responses import HTMLResponse
import numpy as np
from pydantic import BaseModel
import requests
import torch
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title='RAG System with DeepSeek R1')

# Global variables
collection = None
generator_endpoint = None  # For text generation
embed_endpoint = None  # For embeddings

# Dictionary to store in-progress LLM queries
active_requests = {}

class QueryRequest(BaseModel):
    query: str
    n_results: Optional[int] = 3
    temperature: Optional[float] = 0.7


class DocumentsOnlyRequest(BaseModel):
    query: str
    n_results: Optional[int] = 3


class StartLLMRequest(BaseModel):
    request_id: str
    temperature: Optional[float] = 0.7


class SearchResult(BaseModel):
    content: str
    name: str
    split: str
    source: str
    similarity: float


class RAGResponse(BaseModel):
    answer: str
    sources: List[SearchResult]
    thinking_process: str  # Add thinking process to response


class DocumentsOnlyResponse(BaseModel):
    sources: List[SearchResult]
    request_id: str


class LLMStatusResponse(BaseModel):
    status: str  # "pending", "completed", "error"
    answer: Optional[str] = None
    thinking_process: Optional[str] = None
    error: Optional[str] = None

def encode_query(query: str) -> np.ndarray:
    """Encode query text using vLLM embeddings endpoint."""
    global embed_endpoint

    try:
        response = requests.post(f"{embed_endpoint}/v1/embeddings",
                                 json={
                                     "model": "/tmp/embedding_model",
                                     "input": [query]
                                 },
                                 timeout=30)
        response.raise_for_status()

        result = response.json()
        if 'data' not in result:
            raise ValueError(f"Unexpected response format: {result}")

        return np.array(result['data'][0]['embedding'])

    except Exception as e:
        logger.error(f"Error computing query embedding: {str(e)}")
        raise HTTPException(status_code=500,
                            detail="Error computing query embedding")

def query_collection(query_embedding: np.ndarray,
                     n_results: int = 10) -> List[SearchResult]:
    """Query the collection and return top matches."""
    global collection

    # Request more results initially to account for potential duplicates
    max_results = min(n_results * 2, 20)  # Get more results but cap at 20

    results = collection.query(query_embeddings=[query_embedding.tolist()],
                               n_results=max_results,
                               include=['metadatas', 'distances', 'documents'])

    # Get results
    documents = results['documents'][0]
    metadatas = results['metadatas'][0]
    distances = results['distances'][0]

    # Convert distances to similarities
    similarities = [1 - (d / 2) for d in distances]

    # Create a set to track unique content
    seen_content = set()
    unique_results = []

    for doc, meta, sim in zip(documents, metadatas, similarities):
        # Use content as the uniqueness key
        if doc not in seen_content:
            seen_content.add(doc)
            logger.info(f"Found {meta} with similarity {sim}")
            logger.info(f"Content: {doc}")
            unique_results.append((doc, meta, sim))

            # Break if we have enough unique results
            if len(unique_results) >= n_results:
                break

    return [
        SearchResult(content=doc,
                     name=meta['name'],
                     split=meta['split'],
                     source=meta['source'],
                     similarity=sim) for doc, meta, sim in unique_results
    ]

def generate_prompt(query: str, context_docs: List[SearchResult]) -> str:
    """Generate prompt for DeepSeek R1."""
    # Format context with clear document boundaries
    context = "\n\n".join([
        f"[Document {i+1} begin]\nSource: {doc.source}\nContent: {doc.content}\n[Document {i+1} end]"
        for i, doc in enumerate(context_docs)
    ])

    return f"""# The following contents are search results from legal documents and related discussions:

{context}

You are a helpful AI assistant analyzing legal documents and related content. When responding, please follow these guidelines:

First, explain your thinking process between <think> tags. Then provide your final answer after the thinking process.

Question:

{query}

Let's approach this step by step:"""

async def query_llm(prompt: str, temperature: float = 0.7) -> tuple[str, str]:
    """Query DeepSeek R1 through vLLM endpoint and return thinking process and answer."""
    global generator_endpoint

    try:
        response = requests.post(f"{generator_endpoint}/v1/chat/completions",
                                 json={
                                     "model": "/tmp/generation_model",
                                     "messages": [{
                                         "role": "user",
                                         "content": prompt
                                     }],
                                     "temperature": temperature,
                                     "max_tokens": 2048,
                                     "stop": None
                                 },
                                 timeout=120)
        response.raise_for_status()

        logger.info(f"Response: {response.json()}")

        full_response = response.json(
        )['choices'][0]['message']['content'].strip()

        # Split response into thinking process and answer
        parts = full_response.split("</think>")
        if len(parts) > 1:
            thinking = parts[0].replace("<think>", "").strip()
            answer = parts[1].strip()
        else:
            thinking = ""
            answer = full_response

        return thinking, answer

    except Exception as e:
        logger.error(f"Error querying LLM: {str(e)}")
        raise HTTPException(status_code=500,
                            detail="Error querying language model")

@app.post('/rag', response_model=RAGResponse)
async def rag_query(request: QueryRequest):
    """RAG endpoint combining vector search with DeepSeek R1."""
    try:
        # Encode query
        query_embedding = encode_query(request.query)

        # Get relevant documents
        results = query_collection(query_embedding, request.n_results)

        # Generate prompt
        prompt = generate_prompt(request.query, results)

        # Get LLM response
        thinking, answer = await query_llm(prompt, request.temperature)

        return RAGResponse(answer=answer,
                           sources=results,
                           thinking_process=thinking)

    except Exception as e:
        logger.error(f"Error processing RAG query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get('/health')
async def health_check():
    """Health check endpoint."""
    return {
        'status': 'healthy',
        'collection_size': collection.count() if collection else 0
    }

@app.get('/', response_class=HTMLResponse)
async def get_search_page():
    """Serve a simple search interface."""
    # The template lives next to this script, under templates/index.html
    template_path = os.path.join(os.path.dirname(__file__), 'templates',
                                 'index.html')
    try:
        with open(template_path, 'r') as f:
            return f.read()
    except FileNotFoundError:
        raise HTTPException(
            status_code=500,
            detail=f"Template file not found at {template_path}")

@app.post('/documents', response_model=DocumentsOnlyResponse)
async def get_documents(request: DocumentsOnlyRequest):
    """Get relevant documents for a query without LLM processing."""
    try:
        # Encode query
        query_embedding = encode_query(request.query)

        # Get relevant documents
        results = query_collection(query_embedding, request.n_results)

        # Generate a unique request ID
        request_id = str(uuid.uuid4())

        # Store the request data for later LLM processing
        active_requests[request_id] = {
            "query": request.query,
            "results": results,
            "status": "documents_ready",
            "timestamp": time.time()
        }

        # Clean up old requests (older than 30 minutes)
        current_time = time.time()
        expired_requests = [
            req_id for req_id, data in active_requests.items()
            if current_time - data["timestamp"] > 1800
        ]
        for req_id in expired_requests:
            active_requests.pop(req_id, None)

        return DocumentsOnlyResponse(sources=results, request_id=request_id)

    except Exception as e:
        logger.error(f"Error retrieving documents: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post('/process_llm', response_model=LLMStatusResponse)
async def process_llm(request: StartLLMRequest):
    """Process a query with the LLM using previously retrieved documents."""
    request_id = request.request_id

    # Check if the request exists and is ready for LLM processing
    if request_id not in active_requests or active_requests[request_id][
            "status"] != "documents_ready":
        raise HTTPException(status_code=404,
                            detail="Request not found or documents not ready")

    # Mark the request as in progress
    active_requests[request_id]["status"] = "llm_processing"

    try:
        # Get stored data
        query = active_requests[request_id]["query"]
        results = active_requests[request_id]["results"]

        # Generate prompt
        prompt = generate_prompt(query, results)

        # Get LLM response
        thinking, answer = await query_llm(prompt, request.temperature)

        # Store the response and mark as completed
        active_requests[request_id]["status"] = "completed"
        active_requests[request_id]["thinking"] = thinking
        active_requests[request_id]["answer"] = answer
        active_requests[request_id]["timestamp"] = time.time()

        return LLMStatusResponse(status="completed",
                                 answer=answer,
                                 thinking_process=thinking)

    except Exception as e:
        # Mark as error
        active_requests[request_id]["status"] = "error"
        active_requests[request_id]["error"] = str(e)
        active_requests[request_id]["timestamp"] = time.time()

        logger.error(f"Error processing LLM request: {str(e)}")
        return LLMStatusResponse(status="error", error=str(e))

@app.get('/llm_status/{request_id}', response_model=LLMStatusResponse)
async def get_llm_status(request_id: str):
    """Get the status of an LLM request."""
    if request_id not in active_requests:
        raise HTTPException(status_code=404, detail="Request not found")

    request_data = active_requests[request_id]

    if request_data["status"] == "completed":
        return LLMStatusResponse(status="completed",
                                 answer=request_data["answer"],
                                 thinking_process=request_data["thinking"])
    elif request_data["status"] == "error":
        return LLMStatusResponse(status="error",
                                 error=request_data.get("error",
                                                        "Unknown error"))
    else:
        return LLMStatusResponse(status="pending")

def main():
    parser = argparse.ArgumentParser(description='Serve RAG system')
    parser.add_argument('--host',
                        type=str,
                        default='0.0.0.0',
                        help='Host to serve on')
    parser.add_argument('--port',
                        type=int,
                        default=8000,
                        help='Port to serve on')
    parser.add_argument('--collection-name',
                        type=str,
                        default='legal_docs',
                        help='ChromaDB collection name')
    parser.add_argument('--persist-dir',
                        type=str,
                        default='/vectordb/chroma',
                        help='Directory where ChromaDB is persisted')
    parser.add_argument('--generator-endpoint',
                        type=str,
                        required=True,
                        help='Endpoint for text generation service')
    parser.add_argument('--embed-endpoint',
                        type=str,
                        required=True,
                        help='Endpoint for embeddings service')

    args = parser.parse_args()

    # Initialize global variables
    global collection, generator_endpoint, embed_endpoint

    # Set endpoints
    generator_endpoint = args.generator_endpoint.rstrip('/')
    embed_endpoint = args.embed_endpoint.rstrip('/')

    # Initialize ChromaDB
    logger.info(f'Connecting to ChromaDB at {args.persist_dir}')
    client = chromadb.PersistentClient(path=args.persist_dir)

    try:
        collection = client.get_collection(name=args.collection_name)
        logger.info(f'Connected to collection: {args.collection_name}')
        logger.info(f'Total documents in collection: {collection.count()}')
    except ValueError as e:
        logger.error(f'Error: {str(e)}')
        logger.error(
            'Make sure the collection exists and the persist_dir is correct.')
        raise

    # Start server
    uvicorn.run(app, host=args.host, port=args.port)

if __name__ == '__main__':
    main()
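
The `</think>` split inside `query_llm` can be exercised on its own, without a running vLLM endpoint. The sketch below is a standalone illustration and not part of `serve_rag.py`: the helper name `split_thinking` is hypothetical, and it uses `str.partition`, which splits on the first `</think>` only, so any later closing tag stays in the answer.

```python
def split_thinking(full_response: str) -> tuple[str, str]:
    """Split a DeepSeek R1 style completion into (thinking, answer).

    Hypothetical helper that mirrors the split performed in query_llm.
    """
    text = full_response.strip()
    head, sep, tail = text.partition("</think>")
    if not sep:
        # No closing tag: treat the whole completion as the answer.
        return "", text
    thinking = head.replace("<think>", "").strip()
    return thinking, tail.strip()


if __name__ == "__main__":
    thinking, answer = split_thinking(
        "<think>Document 1 covers the statute.</think>\nSee [citation:1].")
    print(thinking)
    print(answer)
```

A completion without a `</think>` tag comes back with an empty thinking string, matching the `else` branch in `query_llm`.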

scripts/templates/index.html

SkyPilot Legal RAG

SkyPilot Legal RAG

Ask

Searching for documents...

DeepSeek is thinking...

    <script>
    function escapeHtml(unsafe) {
        return unsafe
            .replace(/&/g, "&amp;")
            .replace(/</g, "&lt;")
            .replace(/>/g, "&gt;")
            .replace(/"/g, "&quot;")
            .replace(/'/g, "&#039;");
    }

    function highlightSource(docNumber) {
        // Remove previous highlights
        document.querySelectorAll('.highlighted-source').forEach(el => {
            el.classList.remove('highlighted-source');
        });
        
        // Add highlight to clicked source
        const sourceElement = document.querySelector(`[data-doc-number="${docNumber}"]`);
        if (sourceElement) {
            sourceElement.classList.add('highlighted-source');
            sourceElement.scrollIntoView({ behavior: 'smooth', block: 'center' });
        }
    }

    function processCitations(text) {
        // Handle both [citation:X] and Document X formats
        return text
            .replace(/\[citation:(\d+)\]/g, (match, docNumber) => {
                return `<span class="citation" onclick="highlightSource(${docNumber})">[${docNumber}]</span>`;
            })
            .replace(/Document (\d+)/g, (match, docNumber) => {
                return `<span class="citation" onclick="highlightSource(${docNumber})">Document ${docNumber}</span>`;
            });
    }

    async function search() {
        const searchInput = document.getElementById('searchInput');
        const resultsDiv = document.getElementById('results');
        const thinkingIndicator = document.getElementById('thinking-indicator');
        const documentSearchIndicator = document.getElementById('document-search-indicator');
        
        if (!searchInput.value.trim()) return;
        
        // Clear previous results and show document search indicator
        resultsDiv.innerHTML = '';
        documentSearchIndicator.style.display = 'flex';
        thinkingIndicator.style.display = 'none';
        
        // Step 1: Get documents first
        try {
            // First call to get the documents
            const docsResponse = await fetch('/documents', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
                body: JSON.stringify({
                    query: searchInput.value.trim(),
                    n_results: 10
                })
            });
            
            if (!docsResponse.ok) {
                const errorData = await docsResponse.json();
                throw new Error(errorData.detail || 'Failed to retrieve documents');
            }
            
            const docsResult = await docsResponse.json();
            const requestId = docsResult.request_id;
            
            // Hide document search indicator and show DeepSeek indicator
            documentSearchIndicator.style.display = 'none';
            thinkingIndicator.style.display = 'flex';
            
            // Display the documents first
            let sourcesHtml = '<div class="result-section"><h2 class="section-title">Source Documents</h2>';
            docsResult.sources.forEach((source, index) => {
                sourcesHtml += `
                    <div class="source-document" data-doc-number="${index + 1}">
                        <div class="source-header">Source: ${escapeHtml(source.source)}</div>
                        <div class="source-url">URL: ${escapeHtml(source.name)}</div>
                        <div>${escapeHtml(source.content)}</div>
                        <div class="similarity-score">Similarity: ${(source.similarity * 100).toFixed(1)}%</div>
                    </div>
                `;
            });
            sourcesHtml += '</div>';
            
            // Display sources
            resultsDiv.innerHTML = sourcesHtml;
            
            // Step 2: Start the LLM reasoning process in the background
            const llmResponse = await fetch('/process_llm', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                    'Accept': 'application/json'
                },
                body: JSON.stringify({
                    request_id: requestId,
                    temperature: 0.7
                })
            });
            
            if (!llmResponse.ok) {
                const errorData = await llmResponse.json();
                throw new Error(errorData.detail || 'LLM processing failed');
            }
            
            const llmResult = await llmResponse.json();
            
            // Handle different response statuses
            if (llmResult.status === "completed") {
                // Hide thinking indicator
                thinkingIndicator.style.display = 'none';
                
                // Add thinking process and answer at the top
                const thinkingHtml = `
                    <div class="result-section">
                        <h2 class="section-title">Thinking Process</h2>
                        <div class="thinking-process">${processCitations(escapeHtml(llmResult.thinking_process))}</div>
                    </div>
                `;
                
                const answerHtml = `
                    <div class="result-section">
                        <h2 class="section-title">Final Answer</h2>
                        <div class="final-answer">${processCitations(escapeHtml(llmResult.answer)).replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')}</div>
                    </div>
                `;
                
                // Insert before the sources section
                const sourcesSection = document.querySelector('.result-section');
                sourcesSection.insertAdjacentHTML('beforebegin', answerHtml + thinkingHtml);
            } else if (llmResult.status === "error") {
                // Handle error case
                thinkingIndicator.style.display = 'none';
                resultsDiv.insertAdjacentHTML('afterbegin', `
                    <div class="result-section" style="color: #e74c3c;">
                        <h2 class="section-title">Error</h2>
                        <p>${escapeHtml(llmResult.error || "An error occurred while processing the query")}</p>
                    </div>
                `);
            } else {
                // Handle if status is still pending (should not happen with direct call)
                pollForResults(requestId);
            }
            
        } catch (error) {
            documentSearchIndicator.style.display = 'none';
            thinkingIndicator.style.display = 'none';
            resultsDiv.innerHTML = `
                <div class="result-section" style="color: #e74c3c;">
                    <h2 class="section-title">Error</h2>
                    <p>${escapeHtml(error.message)}</p>
                </div>
            `;
        }
    }
    
    // Function to poll for results if needed
    async function pollForResults(requestId) {
        const maxAttempts = 60; // 5 minutes at 5-second intervals
        let attempts = 0;
        const thinkingIndicator = document.getElementById('thinking-indicator');
        
        const poll = async () => {
            if (attempts >= maxAttempts) {
                thinkingIndicator.style.display = 'none';
                const errorHtml = `
                    <div class="result-section" style="color: #e74c3c;">
                        <h2 class="section-title">Timeout</h2>
                        <p>Request timed out after 5 minutes. Please try again.</p>
                    </div>
                `;
                document.getElementById('results').insertAdjacentHTML('afterbegin', errorHtml);
                return;
            }
            
            attempts++;
            
            try {
                const response = await fetch(`/llm_status/${requestId}`);
                if (!response.ok) {
                    throw new Error("Failed to retrieve status");
                }
                
                const result = await response.json();
                
                if (result.status === "completed") {
                    // Hide thinking indicator
                    thinkingIndicator.style.display = 'none';
                    
                    // Add thinking process and answer
                    const thinkingHtml = `
                        <div class="result-section">
                            <h2 class="section-title">Thinking Process</h2>
                            <div class="thinking-process">${processCitations(escapeHtml(result.thinking_process))}</div>
                        </div>
                    `;
                    
                    const answerHtml = `
                        <div class="result-section">
                            <h2 class="section-title">Final Answer</h2>
                            <div class="final-answer">${processCitations(escapeHtml(result.answer)).replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>')}</div>
                        </div>
                    `;
                    
                    // Insert at the beginning of results
                    const sourcesSection = document.querySelector('.result-section');
                    sourcesSection.insertAdjacentHTML('beforebegin', answerHtml + thinkingHtml);
                } else if (result.status === "error") {
                    // Handle error
                    thinkingIndicator.style.display = 'none';
                    const errorHtml = `
                        <div class="result-section" style="color: #e74c3c;">
                            <h2 class="section-title">Error</h2>
                            <p>${escapeHtml(result.error || "An error occurred during processing")}</p>
                        </div>
                    `;
                    document.getElementById('results').insertAdjacentHTML('afterbegin', errorHtml);
                } else {
                    // Still processing, wait and try again
                    setTimeout(poll, 5000); // Check again after 5 seconds
                }
            } catch (error) {
                console.error("Error polling for results:", error);
                setTimeout(poll, 5000); // Try again after 5 seconds
            }
        };
        
        // Start polling
        poll();
    }
    </script>
</body>
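
The `search()` function in this template drives a two-step flow: it POSTs to `/documents`, then POSTs the returned `request_id` to `/process_llm`, and falls back to polling `/llm_status/{request_id}` when the status is neither `completed` nor `error`. The same flow can be sketched from Python. This is only an illustration under stated assumptions: the transport is injected as two callables (`post` and `get`) so the example runs without a server, and `run_two_step_query` and the fake transport are hypothetical names that do not exist in the example repository.

```python
import time
from typing import Any, Callable, Dict

Json = Dict[str, Any]


def run_two_step_query(post: Callable[[str, Json], Json],
                       get: Callable[[str], Json],
                       query: str,
                       n_results: int = 10,
                       temperature: float = 0.7,
                       poll_interval: float = 5.0,
                       max_attempts: int = 60) -> Json:
    """Retrieve documents, then ask the LLM, polling while the result is pending."""
    # Step 1: retrieval only; the service stores the results under request_id.
    docs = post("/documents", {"query": query, "n_results": n_results})
    request_id = docs["request_id"]

    # Step 2: run the LLM on the stored documents.
    result = post("/process_llm", {
        "request_id": request_id,
        "temperature": temperature
    })

    # Fallback: poll the status route until the request settles.
    attempts = 0
    while result["status"] not in ("completed", "error"):
        if attempts >= max_attempts:
            raise TimeoutError("LLM request did not finish in time")
        attempts += 1
        time.sleep(poll_interval)
        result = get(f"/llm_status/{request_id}")

    return {"sources": docs["sources"], **result}


if __name__ == "__main__":
    # Fake transport standing in for the FastAPI service (assumption).
    def fake_post(path: str, payload: Json) -> Json:
        if path == "/documents":
            return {"request_id": "abc", "sources": [{"name": "doc-1"}]}
        return {"status": "pending"}

    def fake_get(path: str) -> Json:
        return {"status": "completed", "answer": "42", "thinking_process": ""}

    print(run_two_step_query(fake_post, fake_get, "test", poll_interval=0.0))
```

Against a real deployment, `post` and `get` would wrap an HTTP client pointed at the service endpoint. The defaults of 60 attempts at 5-second intervals mirror the limits in `pollForResults`.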

serve_rag.yaml

name: serve-legal-rag

workdir: .

resources:
  accelerators: {L4:4, L40S:4}
  memory: 32+
  ports:
    - 8000
  any_of:
    - use_spot: true
    - use_spot: false

envs:
  EMBEDDING_MODEL_NAME: "Alibaba-NLP/gte-Qwen2-7B-instruct"
  GENERATION_MODEL_NAME: "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
  VECTORDB_BUCKET_NAME: sky-rag-vectordb
  VECTORDB_BUCKET_ROOT: /vectordb

file_mounts:
  ${VECTORDB_BUCKET_ROOT}:
    name: ${VECTORDB_BUCKET_NAME}  # this needs to be the same as in build_vectordb.yaml
    mode: MOUNT

setup: |
  # Install dependencies for RAG service
  pip install numpy pandas sentence-transformers requests tqdm
  pip install fastapi uvicorn pydantic chromadb

  # Install dependencies for vLLM
  pip install transformers==4.48.1 vllm==0.6.6.post1 hf_transfer

run: |
  HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/generation_model $GENERATION_MODEL_NAME
  HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/embedding_model $EMBEDDING_MODEL_NAME

  # Start vLLM generation service in background
  CUDA_VISIBLE_DEVICES=0,1,2 python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port 8002 \
    --model /tmp/generation_model \
    --max-model-len 28816 \
    --tensor-parallel-size 2 \
    --task generate &

  # Wait for vLLM to start
  echo "Waiting for vLLM service to be ready..."
  while ! curl -s http://localhost:8002/health > /dev/null; do
    sleep 5
    echo "Still waiting for vLLM service..."
  done
  echo "vLLM service is ready!"

  # Start vLLM embeddings service in background
  CUDA_VISIBLE_DEVICES=3 python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port 8003 \
    --model /tmp/embedding_model \
    --max-model-len 4096 \
    --task embed &

  # Wait for vLLM embeddings service to start
  echo "Waiting for vLLM embeddings service to be ready..."
  while ! curl -s http://localhost:8003/health > /dev/null; do
    sleep 5
    echo "Still waiting for vLLM embeddings service..."
  done
  echo "vLLM embeddings service is ready!"

  # Start RAG service
  python scripts/serve_rag.py \
    --collection-name legal_docs \
    --persist-dir /vectordb/chroma \
    --generator-endpoint http://localhost:8002 \
    --embed-endpoint http://localhost:8003

service:
  replicas: 1
  readiness_probe:
    path: /health
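
The `run` section waits for each vLLM server by polling its `/health` route with `curl` every 5 seconds, and the `service` section points its readiness probe at the RAG service's own `/health` route. A minimal Python sketch of the same wait loop follows, using only the standard library. The names `http_ok` and `wait_until_healthy`, the `max_attempts` bound, and the injectable `probe` parameter are assumptions made for illustration; the shell loops in the YAML have no attempt limit.

```python
import time
import urllib.error
import urllib.request
from typing import Callable, Optional


def http_ok(url: str, timeout: float = 5.0) -> bool:
    """Return True if a GET on url answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return resp.status == 200
    except (urllib.error.URLError, OSError):
        # Connection refused, timeout, or a non-2xx response.
        return False


def wait_until_healthy(url: str,
                       interval: float = 5.0,
                       max_attempts: int = 120,
                       probe: Optional[Callable[[str], bool]] = None) -> bool:
    """Poll url until the probe succeeds or max_attempts is exhausted."""
    check = probe or http_ok
    for _ in range(max_attempts):
        if check(url):
            return True
        time.sleep(interval)
    return False


if __name__ == "__main__":
    # Simulated probe (assumption): unhealthy twice, then healthy.
    states = iter([False, False, True])
    ready = wait_until_healthy("http://localhost:8002/health",
                               interval=0.0,
                               probe=lambda _url: next(states))
    print(ready)
```

Calling `wait_until_healthy("http://localhost:8002/health")` without a `probe` performs real HTTP requests, which is closer to what the `curl` loop does.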