Connect SQLAlchemy to Xata

Learn how to connect a Python application to Xata's PostgreSQL platform using SQLAlchemy, and how to use SQLAlchemy's ORM for robust data operations.

Prerequisites

  • Xata account and project setup
  • Python 3.8+ installed
  • Basic knowledge of Python and SQLAlchemy

Set Up Xata Database

First, set up your Xata database with a simple e-commerce schema (products, orders, and order items):

  1. Create a project and branch in the Xata console
  2. Navigate to the Queries tab in your branch
  3. Run the following SQL commands to create the initial schema:
-- Product catalog.
CREATE TABLE products (
  id SERIAL PRIMARY KEY,
  name TEXT NOT NULL,
  price NUMERIC(7,2) NOT NULL,  -- 7 digits total, 2 after the decimal point (max 99999.99)
  rating INTEGER                -- nullable: a product may have no rating
);

-- One row per order; created_at is filled in by the database when omitted.
CREATE TABLE orders (
  id SERIAL PRIMARY KEY,
  created_at TIMESTAMPTZ DEFAULT NOW()
);

-- Line items: links an order to a product and records the quantity.
-- The composite primary key allows each product at most once per order.
CREATE TABLE order_items (
  order_id INTEGER REFERENCES orders(id),
  product_id INTEGER REFERENCES products(id),
  qty INTEGER NOT NULL,
  PRIMARY KEY (order_id, product_id)
);

Create Project

Create a new Python project:

mkdir my-xata-app
cd my-xata-app
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

Initialize Xata Project

Initialize your Xata project configuration:

xata init

This will create a .xata directory with your project configuration.

Store Your Credentials

Create a .env file in your project root to store your Xata connection string:

# .env
DATABASE_URL="postgresql://username:password@host:port/database"

Get your connection string from the Xata console or CLI:

xata branch url

Important: Never commit your .env file to version control. Add it to your .gitignore file.

Install Dependencies

Install SQLAlchemy, the PostgreSQL driver (psycopg2), and python-decouple (used to read DATABASE_URL from the .env file):

pip install sqlalchemy psycopg2-binary python-decouple

Configure Database Connection

Create a db directory in your project root and add the database configuration file to it:

# db/config.py
import os
from sqlalchemy import create_engine
# declarative_base lives in sqlalchemy.orm since SQLAlchemy 1.4; the old
# sqlalchemy.ext.declarative location is deprecated in 2.0.
from sqlalchemy.orm import declarative_base, sessionmaker
from decouple import config

# default=None makes a missing variable come back as None instead of raising
# decouple's UndefinedValueError, so the explicit check below is reachable and
# reports one clear message for both the "missing" and the "empty" case.
DATABASE_URL = config('DATABASE_URL', default=None)

if not DATABASE_URL:
    raise ValueError("DATABASE_URL environment variable is required")

# Engine: connection factory/pool shared by the whole application.
engine = create_engine(DATABASE_URL)
# Session factory; sessions neither autocommit nor autoflush, so callers
# commit explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Base class that all ORM models inherit from; its metadata collects the tables.
Base = declarative_base()

def get_db():
    """Yield a fresh database session and close it when the caller is done.

    Written as a generator so it can be used as a dependency (for example via
    FastAPI's ``Depends``): the code after ``yield`` runs once the consumer
    has finished, whether or not an exception occurred.
    """
    # Session's context manager closes the session on exit.
    with SessionLocal() as session:
        yield session

Define Models

Create your SQLAlchemy models:

# models.py
from sqlalchemy import Column, Integer, String, Numeric, DateTime, ForeignKey, Table
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from db.config import Base

# Table definition for order line items, registered on Base.metadata.
# Each row links one order to one product and stores the ordered quantity;
# (order_id, product_id) together form the composite primary key.
order_items = Table(
    'order_items',
    Base.metadata,
    Column('order_id', Integer, ForeignKey('orders.id'), primary_key=True),
    Column('product_id', Integer, ForeignKey('products.id'), primary_key=True),
    Column('qty', Integer, nullable=False)
)

class Product(Base):
    """ORM mapping for one row of the ``products`` table."""

    __tablename__ = "products"

    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    price = Column(Numeric(7, 2), nullable=False)
    rating = Column(Integer)  # nullable: a product may be unrated

    # Line items that reference this product (other side: OrderItem.product).
    order_items = relationship("OrderItem", back_populates="product")

    def __repr__(self):
        template = "<Product(id={}, name='{}', price={})>"
        return template.format(self.id, self.name, self.price)

class Order(Base):
    """ORM mapping for one row of the ``orders`` table."""

    __tablename__ = "orders"

    id = Column(Integer, primary_key=True, index=True)
    # Timestamp is assigned by the database (server-side default) on insert.
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # Line items belonging to this order (other side: OrderItem.order).
    order_items = relationship("OrderItem", back_populates="order")

    def __repr__(self):
        template = "<Order(id={}, created_at={})>"
        return template.format(self.id, self.created_at)

class OrderItem(Base):
    """Line item linking one order to one product, with the ordered quantity.

    The ``order_items`` table is already registered on ``Base.metadata`` by
    the module-level ``order_items`` ``Table``. Declaring it a second time via
    ``__tablename__`` plus ``Column`` attributes makes SQLAlchemy raise
    ``InvalidRequestError`` ("Table 'order_items' is already defined for this
    MetaData instance") as soon as this module is imported. The class is
    therefore mapped onto the existing ``Table`` through ``__table__``; the
    ``order_id``, ``product_id`` and ``qty`` attributes are taken from that
    table's columns.
    """

    __table__ = order_items

    # Relationships
    order = relationship("Order", back_populates="order_items")
    product = relationship("Product", back_populates="order_items")

    def __repr__(self):
        return f"<OrderItem(order_id={self.order_id}, product_id={self.product_id}, qty={self.qty})>"

Create Database Operations

Create a service layer for database operations:

# services.py
from sqlalchemy.orm import Session
from sqlalchemy import select, func
from typing import List, Optional
from models import Product, Order, OrderItem
from decimal import Decimal

class ProductService:
    """CRUD and lookup helpers for ``Product`` rows, bound to one session."""

    def __init__(self, db: Session):
        self.db = db

    def get_all_products(self) -> List[Product]:
        """Return every product."""
        all_rows = self.db.query(Product)
        return all_rows.all()

    def get_product_by_id(self, product_id: int) -> Optional[Product]:
        """Return the product with the given id, or ``None`` if there is none."""
        matching = self.db.query(Product).filter(Product.id == product_id)
        return matching.first()

    def create_product(self, name: str, price: Decimal, rating: Optional[int] = None) -> Product:
        """Insert a new product, commit, and return the refreshed instance."""
        session = self.db
        new_product = Product(name=name, price=price, rating=rating)
        session.add(new_product)
        session.commit()
        # Reload so database-generated values (such as the id) are populated.
        session.refresh(new_product)
        return new_product

    def update_product(self, product_id: int, name: str = None, price: Decimal = None, rating: int = None) -> Optional[Product]:
        """Update the fields that were passed (non-``None``) and commit.

        Returns the updated product, or ``None`` when the id does not exist.
        """
        product = self.get_product_by_id(product_id)
        if not product:
            return product

        requested = (("name", name), ("price", price), ("rating", rating))
        for field, value in requested:
            # None means "leave this field unchanged".
            if value is not None:
                setattr(product, field, value)

        self.db.commit()
        self.db.refresh(product)
        return product

    def delete_product(self, product_id: int) -> bool:
        """Delete the product; return ``True`` if it existed, else ``False``."""
        product = self.get_product_by_id(product_id)
        if not product:
            return False
        self.db.delete(product)
        self.db.commit()
        return True

    def get_products_by_rating(self, rating: int) -> List[Product]:
        """Return all products whose rating equals ``rating``."""
        rated = self.db.query(Product).filter(Product.rating == rating)
        return rated.all()

    def get_products_by_price_range(self, min_price: Decimal, max_price: Decimal) -> List[Product]:
        """Return products priced within ``[min_price, max_price]`` (inclusive)."""
        at_least_min = Product.price >= min_price
        at_most_max = Product.price <= max_price
        return self.db.query(Product).filter(at_least_min, at_most_max).all()

class OrderService:
    """Order-related database operations, bound to one session."""

    def __init__(self, db: Session):
        self.db = db

    def get_all_orders(self) -> List[Order]:
        """Return every order."""
        return self.db.query(Order).all()

    def get_order_by_id(self, order_id: int) -> Optional[Order]:
        """Return the order with the given id, or ``None`` if there is none."""
        return self.db.query(Order).filter(Order.id == order_id).first()

    def create_order(self) -> Order:
        """Insert an empty order, commit, and return the refreshed instance."""
        order = Order()
        self.db.add(order)
        self.db.commit()
        # Reload so the generated id and server-side created_at are populated.
        self.db.refresh(order)
        return order

    def add_item_to_order(self, order_id: int, product_id: int, qty: int) -> Optional[OrderItem]:
        """Attach ``qty`` units of a product to an order.

        Returns the new line item, or ``None`` when either the order or the
        product does not exist. Because (order_id, product_id) is the primary
        key of the line-item table, adding the same product to the same order
        a second time is expected to be rejected on commit.
        """
        order = self.get_order_by_id(order_id)
        product = self.db.query(Product).filter(Product.id == product_id).first()

        if order and product:
            order_item = OrderItem(order_id=order_id, product_id=product_id, qty=qty)
            self.db.add(order_item)
            self.db.commit()
            self.db.refresh(order_item)
            return order_item
        return None

    def get_order_with_items(self, order_id: int) -> Optional[Order]:
        """Return the order with its line items and their products preloaded.

        The items and products are loaded eagerly so the result remains usable
        after the session has been closed; without the loader options this
        method would be identical to ``get_order_by_id``. Returns ``None``
        when the order does not exist.
        """
        # Local import: the module-level imports do not include loader options.
        from sqlalchemy.orm import selectinload

        return (
            self.db.query(Order)
            .options(selectinload(Order.order_items).selectinload(OrderItem.product))
            .filter(Order.id == order_id)
            .first()
        )

    def get_orders_with_product_details(self) -> List[dict]:
        """Get orders with product details and calculated totals.

        Returns one dict per line item with the keys ``order_id``,
        ``created_at``, ``product_name``, ``price``, ``qty`` and ``total``
        (price multiplied by quantity).
        """
        query = self.db.query(
            Order.id.label('order_id'),
            Order.created_at,
            Product.name.label('product_name'),
            Product.price,
            OrderItem.qty,
            (Product.price * OrderItem.qty).label('total')
        ).join(OrderItem, Order.id == OrderItem.order_id)\
         .join(Product, OrderItem.product_id == Product.id)

        # SQLAlchemy 2.x rows behave like named tuples, so dict(row) fails;
        # the mapping view is exposed through Row._mapping.
        return [dict(row._mapping) for row in query.all()]

Create Sample Application

Create a simple application to demonstrate the functionality:

# app.py
from decimal import Decimal
# Base and engine are defined in db.config; models.py does not provide
# `engine`, so importing it from there raises ImportError.
from db.config import Base, SessionLocal, engine
from services import ProductService, OrderService
import models  # noqa: F401 - imported so the ORM tables are registered on Base.metadata

def main():
    """Run the demo: seed products, place an order, report it, update a product."""
    # Create tables (in production, use migrations)
    Base.metadata.create_all(bind=engine)

    # The session is closed automatically when the block exits, even on error.
    with SessionLocal() as db:
        product_service = ProductService(db)
        order_service = OrderService(db)

        # --- Seed the catalog -------------------------------------------
        print("Creating sample products...")
        sample_catalog = (
            ("Wireless Headphones", Decimal("99.99"), 5),
            ("Smartphone", Decimal("699.99"), 4),
            ("Laptop", Decimal("1299.99"), 5),
            ("Coffee Maker", Decimal("89.99"), 4),
        )
        products = [
            product_service.create_product(name, price, rating)
            for name, price, rating in sample_catalog
        ]

        print(f"Created {len(products)} products")

        # --- List everything --------------------------------------------
        print("\nAll products:")
        for item in product_service.get_all_products():
            print(f"- {item.name}: ${item.price} (Rating: {item.rating}/5)")

        # --- Filter by rating -------------------------------------------
        print("\nTop-rated products (5 stars):")
        for item in product_service.get_products_by_rating(5):
            print(f"- {item.name}: ${item.price}")

        # --- Place an order ---------------------------------------------
        print("\nCreating an order...")
        order = order_service.create_order()
        print(f"Created order #{order.id}")

        # 2 headphones and 1 laptop
        basket = ((products[0], 2), (products[2], 1))
        for chosen, quantity in basket:
            order_service.add_item_to_order(order.id, chosen.id, quantity)

        # --- Report order lines with computed totals --------------------
        print("\nOrder details:")
        for line in order_service.get_orders_with_product_details():
            print(f"Order #{line['order_id']}: {line['product_name']} x{line['qty']} = ${line['total']}")

        # --- Rename and reprice the headphones --------------------------
        print("\nUpdating product...")
        headphones = products[0]
        updated_product = product_service.update_product(
            headphones.id,
            name="Premium Wireless Headphones",
            price=Decimal("129.99"),
        )
        print(f"Updated: {updated_product.name} - ${updated_product.price}")


if __name__ == "__main__":
    main()

Create Requirements File

Create a requirements.txt file:

# requirements.txt
sqlalchemy>=2.0.0
psycopg2-binary>=2.9.0
python-decouple>=3.8

Run the Application

Install dependencies and run the application:

pip install -r requirements.txt
python app.py

Create FastAPI Integration (Optional)

For a web API, you can integrate with FastAPI:

# api.py
from decimal import Decimal
from typing import List, Optional

from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session

from db.config import get_db
from models import Order, Product
from services import OrderService, ProductService

# ASGI application object; this is the `app` that `uvicorn api:app` serves.
app = FastAPI(title="Xata SQLAlchemy API")

# Pydantic models for API
class ProductCreate(BaseModel):
    """Request body for creating a product."""

    name: str
    price: Decimal
    # Optional[int] rather than plain int: the rating column is nullable, and
    # a plain `int` annotation rejects an explicit JSON null from the client.
    rating: Optional[int] = None

class ProductResponse(BaseModel):
    """Response shape for a product, built directly from the ORM object."""

    # from_attributes lets the model be populated from attribute access on an
    # ORM instance. ConfigDict replaces the nested `class Config`, which
    # Pydantic v2 (where from_attributes was introduced) deprecates.
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    price: Decimal
    # Optional: products stored without a rating have rating NULL, which a
    # plain `int` field would fail to validate when serializing the response.
    rating: Optional[int] = None

@app.get("/products", response_model=List[ProductResponse])
def get_products(db: Session = Depends(get_db)):
    # List every product; the session comes from the get_db dependency.
    return ProductService(db).get_all_products()

@app.get("/products/{product_id}", response_model=ProductResponse)
def get_product(product_id: int, db: Session = Depends(get_db)):
    # Fetch a single product; an unknown id is reported as HTTP 404.
    found = ProductService(db).get_product_by_id(product_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Product not found")

@app.post("/products", response_model=ProductResponse)
def create_product(product: ProductCreate, db: Session = Depends(get_db)):
    # Persist the validated request body and return the stored product.
    return ProductService(db).create_product(product.name, product.price, product.rating)

@app.get("/products/rating/{rating}", response_model=List[ProductResponse])
def get_products_by_rating(rating: int, db: Session = Depends(get_db)):
    # List the products whose rating equals the path parameter.
    return ProductService(db).get_products_by_rating(rating)

To run the FastAPI version:

pip install fastapi uvicorn
uvicorn api:app --reload

Next Steps