Building a Cross-Shard Database Transaction Coordinator with ActiveMQ and the Saga Pattern


When database writes become the bottleneck, horizontal sharding is practically the only way out. We split the user and order tables by user_id across independent AWS RDS instances. That solved the throughput problem, but a thornier one surfaced: the transactions that the database's ACID guarantees protected in the monolith are now thoroughly broken. A typical order flow, "debit the user's balance" plus "create the order record", is now spread across two different database shards and can no longer be wrapped in a simple @transaction.atomic.

# The old, beautiful, but now broken world
from db import session

def create_order(user_id, amount, items):
    try:
        with session.begin():
            # This user is on shard 1
            user = session.query(User).filter_by(id=user_id).with_for_update().one()
            if user.balance < amount:
                raise InsufficientFundsError("Balance not enough")
            user.balance -= amount

            # This order is also on shard 1 (sharded by user_id)
            new_order = Order(user_id=user_id, amount=amount, items=items)
            session.add(new_order)
        return new_order.id
    except Exception as e:
        # DB transaction automatically rolls back
        log.error(f"Order creation failed for user {user_id}: {e}")
        raise

Once the User and Inventory tables are sharded along different dimensions (say, users by user_id and inventory by product_id), the code above falls apart completely. Cross-shard atomicity becomes an architectural problem that must be solved.

Initial Ideas and Trade-offs

The initial discussion centered on two-phase commit (2PC). It provides strong consistency, but its inherently synchronous, blocking model is disastrous for performance and availability. In 2PC, the transaction coordinator (TM) must lock resources on every participant (RM) until all participants are prepared; a network hiccup or crash on any single participant leaves the whole transaction hanging and its resources locked. On AWS, where instances can be restarted at any moment, the availability risk of 2PC was too high.

We ultimately chose the event-driven Saga pattern. The core idea of a Saga is to split one long transaction into a sequence of local transactions, each with a corresponding compensating action. If any step of the Saga fails, the system executes the compensations in reverse order, guaranteeing eventual consistency at the macro level.

There are two main ways to implement a Saga: choreography and orchestration. In choreography, each service publishes an event after completing its local transaction, and other services listen for those events and trigger their own local transactions. This approach is decentralized, but the dependencies between services become complex and hard to trace; for any transaction with more than three steps, debugging and maintenance turn into a nightmare.

We therefore went with an orchestrated Saga. A centralized "Saga orchestrator" sends commands to each participant and, based on their replies, decides whether to proceed to the next step or to trigger the compensation flow. In this model the transaction logic is concentrated in the orchestrator, so the flow is clear and easy to monitor and manage.

A message queue is the ideal tool for implementing the Saga orchestrator: it provides asynchronous communication, load leveling, and, most importantly, reliable message delivery. On AWS we could choose between Amazon SQS/SNS and Amazon MQ. Given the team's familiarity with the JMS specification and ActiveMQ's good support for transactional messaging, we picked Amazon MQ running ActiveMQ as the underlying broker. It gives us the convenience of a managed service while retaining the control we need.

Architecture and Implementation Details

Our goal was a generic, configurable Saga orchestrator service. It is itself a Flask application responsible for defining Saga flows, tracking their state, and talking to ActiveMQ.

sequenceDiagram
    participant Client
    participant OrderService as OrderService (Flask API)
    participant SagaOrchestrator as SagaOrchestrator (Flask/ActiveMQ Listener)
    participant AccountService as AccountService (Participant)
    participant InventoryService as InventoryService (Participant)
    participant SagaStateDB

    Client->>+OrderService: POST /orders (user_id, product_id, amount)
    OrderService->>+SagaOrchestrator: start_saga("CREATE_ORDER_SAGA", payload)
    SagaOrchestrator->>SagaStateDB: CREATE SagaInstance(status=STARTED)
    Note over SagaOrchestrator: Saga state machine starts
    SagaOrchestrator->>AccountService: Send Command: "account.debit.command"
    activate AccountService
    Note over AccountService: Performs local DB transaction on user shard
    AccountService-->>SagaOrchestrator: Send Reply: "saga.reply.queue" (status=SUCCESS)
    deactivate AccountService
    SagaOrchestrator->>SagaStateDB: UPDATE SagaInstance(status=DEBITED)
    SagaOrchestrator->>InventoryService: Send Command: "inventory.reserve.command"
    activate InventoryService
    Note over InventoryService: Fails local transaction (e.g., out of stock)
    InventoryService-->>SagaOrchestrator: Send Reply: "saga.reply.queue" (status=FAILURE, reason="Out of Stock")
    deactivate InventoryService
    SagaOrchestrator->>SagaStateDB: UPDATE SagaInstance(status=COMPENSATING)
    Note over SagaOrchestrator: Compensation logic triggered
    SagaOrchestrator->>AccountService: Send Command: "account.credit.command" (Compensation)
    activate AccountService
    Note over AccountService: Performs compensating local transaction
    AccountService-->>SagaOrchestrator: Send Reply: "saga.reply.queue" (status=COMPENSATION_SUCCESS)
    deactivate AccountService
    SagaOrchestrator->>SagaStateDB: UPDATE SagaInstance(status=FAILED_ROLLED_BACK)
    SagaOrchestrator-->>-OrderService: saga_result(status=FAILURE)
    OrderService-->>-Client: HTTP 500 / 400 (Order Failed)

1. Core Logic of the Saga Orchestrator

The orchestrator is the hub of the whole system. We use Python's stomp.py library to talk to ActiveMQ, and a simple database (a non-sharded PostgreSQL on RDS) to persist the state of each Saga instance.
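
The service code below imports SagaInstance from a models module the article does not show. As a point of reference, a minimal SQLAlchemy model consistent with how the orchestrator uses it (id, saga_name, status, payload, current_step, failure_reason) might look like this; the real definition may differ:

# orchestrator/models.py -- a minimal sketch, not the article's actual model
import uuid
from sqlalchemy import Column, String, Integer, JSON
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class SagaInstance(Base):
    __tablename__ = 'saga_instances'

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    saga_name = Column(String(64), nullable=False)      # e.g. "CREATE_ORDER_SAGA"
    status = Column(String(32), nullable=False)         # STARTED / COMPENSATING / COMPLETED / ...
    payload = Column(JSON, nullable=False)              # business data passed to every command
    current_step = Column(Integer, nullable=False, default=0)
    failure_reason = Column(String(255), nullable=True)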

Saga definition:
We define each Saga flow as a plain Python dictionary. Every step contains an executable task (task) and a compensating task (compensation).

# orchestrator/sagas.py

SAGA_DEFINITIONS = {
    "CREATE_ORDER_SAGA": {
        "steps": [
            {
                "name": "DEBIT_ACCOUNT",
                "task": "account.debit.command",
                "compensation": "account.credit.command"
            },
            {
                "name": "RESERVE_INVENTORY",
                "task": "inventory.reserve.command",
                "compensation": "inventory.release.command"
            },
            {
                "name": "CREATE_ORDER_RECORD",
                "task": "order.create.command",
                "compensation": "order.cancel.command"
            }
        ]
    }
}

Orchestrator main service (orchestrator/service.py):
This service listens on the reply queue and drives the flow through the Saga state machine.

# orchestrator/service.py
import stomp
import json
import uuid
import time
import logging
from threading import Thread
from config import settings
from db import Session
from models import SagaInstance
from sagas import SAGA_DEFINITIONS

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class SagaOrchestrator(stomp.ConnectionListener):
    def __init__(self, conn):
        self.conn = conn
        self.correlation_map = {} # In-memory map for simplicity, production needs Redis/DB

    def on_error(self, frame):
        logging.error(f'Received an error "{frame.body}"')

    def on_message(self, frame):
        session = None
        try:
            message = json.loads(frame.body)
            correlation_id = frame.headers.get('correlation-id')
            saga_id = message.get('saga_id')
            status = message.get('status')

            logging.info(f"Received reply for saga {saga_id} (correlation_id: {correlation_id}): {status}")

            session = Session()
            saga_instance = session.query(SagaInstance).filter_by(id=saga_id).one_or_none()
            if not saga_instance:
                logging.error(f"Saga instance {saga_id} not found!")
                # Here we should probably send to a dead-letter-queue
                self.conn.ack(frame.headers['message-id'], frame.headers['subscription'])
                return

            if status == 'SUCCESS':
                self.handle_success(saga_instance, session)
            else:
                self.handle_failure(saga_instance, session, message.get('reason'))

            # Acknowledge only after the new saga state has been committed
            self.conn.ack(frame.headers['message-id'], frame.headers['subscription'])

        except Exception:
            logging.exception(f"Error processing message: {frame.body}")
            # Do not acknowledge the message, let the broker redeliver it later
            self.conn.nack(frame.headers['message-id'], frame.headers['subscription'])
        finally:
            if session is not None:
                session.close()

    def handle_success(self, saga_instance, session):
        saga_def = SAGA_DEFINITIONS[saga_instance.saga_name]
        saga_instance.current_step += 1

        if saga_instance.current_step >= len(saga_def['steps']):
            saga_instance.status = 'COMPLETED'
            logging.info(f"Saga {saga_instance.id} completed successfully.")
        else:
            self.execute_next_step(saga_instance, saga_def)
        
        session.commit()

    def handle_failure(self, saga_instance, session, reason):
        saga_instance.status = 'COMPENSATING'
        saga_instance.failure_reason = reason
        logging.warning(f"Saga {saga_instance.id} failed at step {saga_instance.current_step}. Reason: {reason}. Starting compensation.")
        self.execute_compensation(saga_instance, session)
        session.commit()

    def execute_next_step(self, saga_instance, saga_def):
        step_def = saga_def['steps'][saga_instance.current_step]
        logging.info(f"Saga {saga_instance.id}: executing step {step_def['name']}")
        
        correlation_id = str(uuid.uuid4())
        command = {
            "saga_id": saga_instance.id,
            "payload": saga_instance.payload
        }
        
        self.conn.send(
            body=json.dumps(command),
            destination=f'/queue/{step_def["task"]}',
            headers={
                'correlation-id': correlation_id,
                'reply-to': settings.SAGA_REPLY_QUEUE
            }
        )

    def execute_compensation(self, saga_instance, session):
        saga_def = SAGA_DEFINITIONS[saga_instance.saga_name]
        # Compensate backwards from the last completed step
        step_to_compensate_idx = saga_instance.current_step - 1

        if step_to_compensate_idx < 0:
            saga_instance.status = 'FAILED'
            logging.info(f"Saga {saga_instance.id} failed before any step completed. No compensation needed.")
            session.commit()
            return
        
        step_def = saga_def['steps'][step_to_compensate_idx]
        compensation_task = step_def.get('compensation')

        if compensation_task:
            logging.info(f"Saga {saga_instance.id}: executing compensation for step {step_def['name']}")
            correlation_id = str(uuid.uuid4())
            command = {
                "saga_id": saga_instance.id,
                "payload": saga_instance.payload
            }
            # In a real system, compensation replies should be handled to ensure they succeed
            self.conn.send(
                body=json.dumps(command),
                destination=f'/queue/{compensation_task}',
                headers={'correlation-id': correlation_id}
            )
            # A more robust system would track compensation status per step
            saga_instance.current_step -= 1
            if saga_instance.current_step > 0:
                # Earlier completed steps remain; keep compensating backwards
                self.execute_compensation(saga_instance, session)
            else:
                saga_instance.status = 'ROLLED_BACK'
                logging.info(f"Saga {saga_instance.id} fully compensated and rolled back.")

        else:
            logging.warning(f"No compensation task defined for step {step_def['name']}")
            saga_instance.current_step -= 1
            if saga_instance.current_step > 0:
                self.execute_compensation(saga_instance, session)
            else:
                saga_instance.status = 'ROLLED_BACK'
                logging.info(f"Saga {saga_instance.id} fully compensated and rolled back.")

def start_saga(saga_name, payload):
    session = Session()
    try:
        saga_def = SAGA_DEFINITIONS.get(saga_name)
        if not saga_def:
            raise ValueError(f"Saga definition for {saga_name} not found.")

        instance = SagaInstance(
            saga_name=saga_name,
            status='STARTED',
            payload=payload,
            current_step=0
        )
        session.add(instance)
        session.commit()
        
        # Start the first step
        orchestrator.execute_next_step(instance, saga_def)
        session.commit()
        return instance.id
    finally:
        session.close()

# Main application setup
conn = stomp.Connection([(settings.MQ_HOST, settings.MQ_PORT)])
orchestrator = SagaOrchestrator(conn)
conn.set_listener('', orchestrator)
conn.connect(settings.MQ_USER, settings.MQ_PASS, wait=True)
# client-individual ack: a reply is only acknowledged once the saga state is persisted
conn.subscribe(destination=settings.SAGA_REPLY_QUEUE, id='saga_reply_sub', ack='client-individual')
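
The sequence diagram shows OrderService invoking start_saga. How that call is exposed is not shown in the article; one plausible sketch is a thin Flask route on the orchestrator that wraps start_saga (the route path, module layout, and response shape here are assumptions):

# orchestrator/api.py -- illustrative sketch; route and payload shape are assumptions
from flask import Flask, request, jsonify
from service import start_saga

app = Flask(__name__)

@app.route('/sagas/create-order', methods=['POST'])
def create_order_saga():
    payload = request.get_json()
    # Kick off the saga; the outcome arrives asynchronously on the reply queue
    saga_id = start_saga("CREATE_ORDER_SAGA", payload)
    # 202: accepted for processing, the caller polls or subscribes for the result
    return jsonify({"saga_id": saga_id}), 202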

Key points:

  1. State persistence: every state transition (executing the next step, starting compensation) must first update the SagaInstance row in the database. Even if the orchestrator crashes and restarts, it can recover its state from the database and decide what to do next (or resend messages); a recovery sketch follows after this list.
  2. Message headers: correlation-id ties a request to its reply, and reply-to tells the participant which queue to send its result to.
  3. Idempotency: compensating actions, and some business actions, must be idempotent. If the orchestrator crashes right after sending a compensation command, it may send it again after restarting, so participants must handle duplicate messages without side effects.
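
To make the recovery claim in point 1 concrete, here is a minimal sketch of a startup routine that re-drives in-flight sagas from their persisted state. The module and function names are illustrative, and it leans on participant idempotency (point 3) because commands may be re-sent:

# orchestrator/recovery.py -- illustrative sketch, not part of the original service
from db import Session
from models import SagaInstance
from sagas import SAGA_DEFINITIONS

def recover_in_flight_sagas(orchestrator):
    """Run once at startup: re-drive sagas that were mid-flight when we crashed."""
    session = Session()
    try:
        in_flight = session.query(SagaInstance).filter(
            SagaInstance.status.in_(['STARTED', 'COMPENSATING'])
        ).all()
        for saga in in_flight:
            saga_def = SAGA_DEFINITIONS[saga.saga_name]
            if saga.status == 'COMPENSATING':
                # Resume the rollback from the last recorded step
                orchestrator.execute_compensation(saga, session)
            else:
                # Re-send the command for the step that was in progress;
                # idempotent participants make duplicates harmless
                orchestrator.execute_next_step(saga, saga_def)
        session.commit()
    finally:
        session.close()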

2. Designing the Participant Services

A participant service is a simple Flask application that only cares about its own local transaction. Take AccountService as an example:

# account_service/app.py
import stomp
import json
import time
import logging
from flask import Flask
from threading import Thread
from config import settings
from db import get_sharded_session  # This function connects to the correct DB shard
from models import User             # User model for the account shard (module path assumed)

app = Flask(__name__)

class AccountServiceListener(stomp.ConnectionListener):
    def __init__(self, conn):
        self.conn = conn

    def on_message(self, frame):
        correlation_id = frame.headers.get('correlation-id')
        reply_to = frame.headers.get('reply-to')
        saga_id = None
        session = None

        # The key is to wrap message processing in a local DB transaction.
        # This emulates a "transacted session" in JMS.
        try:
            message = json.loads(frame.body)
            saga_id = message['saga_id']
            user_id = message['payload']['user_id']
            amount = message['payload']['amount']

            session = get_sharded_session(user_id)

            # Message routing based on destination queue
            queue_name = frame.headers.get('destination').split('/')[-1]
            if queue_name == 'account.debit.command':
                self.handle_debit(session, saga_id, user_id, amount, reply_to, correlation_id)
            elif queue_name == 'account.credit.command':
                self.handle_credit(session, saga_id, user_id, amount, correlation_id)

            session.commit()
            self.conn.ack(frame.headers['message-id'], frame.headers['subscription'])

        except ValueError as e:
            # Business failure (e.g. insufficient funds): ack so the command is not
            # redelivered, and notify the orchestrator so it can start compensation
            logging.warning(f"Business failure for saga {saga_id}: {e}")
            if session:
                session.rollback()
            self.conn.ack(frame.headers['message-id'], frame.headers['subscription'])
            if reply_to:
                reply = {'saga_id': saga_id, 'status': 'FAILURE', 'reason': str(e)}
                self.conn.send(body=json.dumps(reply), destination=reply_to,
                               headers={'correlation-id': correlation_id})

        except Exception:
            # Technical failure: roll back and nack so the broker redelivers later
            logging.exception("Failed to process account command")
            if session:
                session.rollback()
            self.conn.nack(frame.headers['message-id'], frame.headers['subscription'])
        finally:
            if session:
                session.close()

    def handle_debit(self, session, saga_id, user_id, amount, reply_to, correlation_id):
        # A real implementation needs an idempotency check using saga_id + step_name
        user = session.query(User).filter_by(id=user_id).with_for_update().one()
        if user.balance < amount:
            raise ValueError("Insufficient funds")
        user.balance -= amount
        
        reply = {'saga_id': saga_id, 'status': 'SUCCESS'}
        self.conn.send(body=json.dumps(reply), destination=reply_to, headers={'correlation-id': correlation_id})
        logging.info(f"Debited {amount} from user {user_id} for saga {saga_id}")

    def handle_credit(self, session, saga_id, user_id, amount, correlation_id):
        # Compensation logic must be idempotent and should not fail
        user = session.query(User).filter_by(id=user_id).with_for_update().one()
        user.balance += amount
        logging.info(f"COMPENSATION: Credited {amount} to user {user_id} for saga {saga_id}")
        # Compensation might also send a reply for tracking purposes

def run_listener():
    conn = stomp.Connection([(settings.MQ_HOST, settings.MQ_PORT)])
    conn.set_listener('', AccountServiceListener(conn))
    conn.connect(settings.MQ_USER, settings.MQ_PASS, wait=True)
    # Use client-individual acknowledgment for transactional behavior
    conn.subscribe(destination='/queue/account.debit.command', id='debit_sub', ack='client-individual')
    conn.subscribe(destination='/queue/account.credit.command', id='credit_sub', ack='client-individual')
    while True:
        time.sleep(10)

if __name__ == '__main__':
    listener_thread = Thread(target=run_listener)
    listener_thread.daemon = True
    listener_thread.start()
    app.run(host='0.0.0.0', port=5001)

Key principles for participants:

  1. Local transactions: all database operations are wrapped in a single local transaction.
  2. Idempotent receivers: handle_credit (the compensation) must be idempotent. A simple approach is a processed_saga_steps table checked for (saga_id, step_name) before performing the operation; see the sketch after this list.
  3. Atomic message consumption: use ack='client-individual'. Send the ack to ActiveMQ only after the local database transaction has committed; if processing fails for technical reasons, roll back the transaction and nack so the broker retries later. This ties the database operation and the message consumption together (at-least-once delivery, which is why the idempotency in point 2 matters).
  4. Explicit failure notification: if the business logic fails (for example, insufficient funds), the service must explicitly send a failure message to the reply-to queue so the orchestrator can start compensation immediately instead of waiting for a timeout.
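
A minimal sketch of the idempotency check mentioned in point 2, assuming a processed_saga_steps table with a unique key on (saga_id, step_name); the table and helper names are illustrative:

# account_service/idempotency.py -- illustrative sketch
from sqlalchemy import text

def already_processed(session, saga_id, step_name):
    """Return True if this (saga_id, step_name) was already handled on this shard."""
    row = session.execute(
        text("SELECT 1 FROM processed_saga_steps WHERE saga_id = :sid AND step_name = :step"),
        {"sid": saga_id, "step": step_name},
    ).first()
    return row is not None

def mark_processed(session, saga_id, step_name):
    """Record the step inside the same local transaction as the business change."""
    session.execute(
        text("INSERT INTO processed_saga_steps (saga_id, step_name) VALUES (:sid, :step)"),
        {"sid": saga_id, "step": step_name},
    )

handle_credit would call already_processed first and return immediately on a duplicate; because mark_processed runs inside the same local transaction as the balance update, a crash before commit leaves no trace and the retried message behaves like a first attempt.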

Deployment Considerations on AWS

  • Amazon MQ: we chose Amazon MQ over self-managed ActiveMQ on EC2 to reduce operational burden. Note that the broker should run in active/standby mode for high availability, and the network configuration (security groups) must be tightly controlled so that only the services that need it (the orchestrator and the participants) can reach the broker's ports.
  • AWS RDS for shards: each shard is an independent RDS instance. get_sharded_session picks a connection string dynamically based on a hash of user_id (or another shard key); this logic must be fast and correct. A routing sketch follows after this list.
  • IAM roles: every service (the Flask applications run on EC2 or ECS) should use IAM roles to obtain credentials for other AWS services (RDS, Amazon MQ) instead of hard-coding access keys.
  • Logging and monitoring: every service should emit structured JSON logs and ship them to Amazon CloudWatch Logs. Including saga_id and correlation_id in every log line is crucial: it lets us query the full lifecycle of a Saga with CloudWatch Logs Insights even when it spans multiple services and databases.
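
As a reference for the RDS bullet above, a minimal routing sketch for get_sharded_session, assuming a static list of shard connection strings and modulo hashing; the URLs and the hashing scheme are assumptions, not the production logic:

# db.py -- routing sketch; SHARD_URLS and the hash scheme are assumptions
import hashlib
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

# In practice these would come from configuration or AWS Secrets Manager
SHARD_URLS = [
    "postgresql://app@shard-0.example.rds.amazonaws.com/app",
    "postgresql://app@shard-1.example.rds.amazonaws.com/app",
]

# One engine and session factory per shard, created once at import time
_engines = [create_engine(url, pool_pre_ping=True) for url in SHARD_URLS]
_session_factories = [sessionmaker(bind=engine) for engine in _engines]

def get_sharded_session(user_id):
    """Return a session bound to the RDS shard that owns this user_id.

    A stable hash is used so every service instance routes the same key to
    the same shard (Python's built-in hash() is salted per process).
    """
    digest = hashlib.md5(str(user_id).encode()).hexdigest()
    shard_index = int(digest, 16) % len(SHARD_URLS)
    return _session_factories[shard_index]()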

Open Issues and Future Iterations

This Saga-based coordinator solves the core cross-shard transaction problem, but it is far from perfect. Several aspects still need hardening for a real production environment.

First, the orchestrator itself is a single point of failure. Its state is persisted, but if the orchestrator process stays down for a long time, every in-flight Saga is stuck. A highly available deployment would run the orchestrator in active/standby mode, or rely on an orchestration platform such as Kubernetes to guarantee that an instance is always running.

Second, observability remains a challenge. We currently rely on ID-based log aggregation queries, which is inefficient when hunting performance bottlenecks or diagnosing complex failure scenarios. Introducing distributed tracing (for example, OpenTelemetry) and propagating a trace_id from the start of the Saga through every participant would greatly improve our understanding of the system's behavior.

Finally, the Saga pattern adds complexity to the business logic. Developers must write not only the core business logic but also its compensations, and compensation logic is especially hard to test: fault injection or chaos engineering is needed to validate its robustness. Our current testing relies on unit tests and end-to-end integration tests, with no simulation of real-world failures such as network partitions or service timeouts. A future iteration is a test harness that can conveniently simulate participant failures, so we can verify that compensation flows execute as expected.
