基于阿里云与TensorFlow构建集成了OAuth 2.0身份认证的智能API滥用检测系统


我们基于Redis的滑动窗口计数器在周二凌晨3点彻底失效。攻击者没有使用传统的暴力请求,而是采用了一种低频、多端点轮询的策略,模拟大量真实用户的行为来抓取核心数据。每个IP、每个client_id的请求频率都在我们的阈值之下,但从全局看,这是一场精心策划的、缓慢而持续的数据窃取攻击。这次事件暴露了我们现有API网关安全策略的致命缺陷:它只能识别“莽夫”,无法识别“窃贼”。这成了我们接下来两个敏捷迭代周期的唯一目标——构建一个能够理解行为模式的智能滥用检测系统。

Sprint 0:技术选型与架构构想的挣扎

最初的方案是加固现有的规则引擎,增加更复杂的组合规则,例如“同一client_id在5分钟内访问超过20个不同商品详情页API端点即封禁”。但这很快被否决,因为我们意识到这会陷入一个“打地鼠”的游戏。我们今天添加了商品规则,明天攻击者就会转向用户资料端点。我们需要一个能自我学习和适应的系统。

机器学习是唯一的出路。我们的技术栈围绕着阿里云展开,自然而然地,我们评估了平台上的PAI服务。但考虑到团队对模型细节的掌控欲以及与现有监控体系集成的灵活性,我们最终决定自建一套基于TensorFlow的方案。

整个架构的决策过程充满了权衡:

  1. 数据源: API网关的访问日志是核心。我们需要一个持久化、可供复杂查询的数据存储。虽然Elasticsearch或ClickHouse在日志分析上更专业,但考虑到团队对MySQL的熟悉度、事务支持(用于未来可能的标签系统)以及阿里云RDS的成熟度,我们选择MySQL作为日志和特征的存储后端。这是一个在性能和维护成本之间做出的务实选择。
  2. 身份识别: 滥用行为的主体是client_iduser_id,这些信息由我们的OAuth 2.0授权服务器在令牌中提供。因此,API网关在记录日志时,必须能够解析JWT令牌并提取这些关键身份字段。这是将安全身份与行为分析连接起来的枢纽。
  3. 模型训练与服务: 训练任务可以离线进行,使用ECS实例或PAI-DSW都可以。但推理服务必须是低延迟、高可用的。阿里云函数计算(Function Compute, FC)成了最佳选择。它能提供毫秒级的弹性伸缩,而且我们只需为实际调用付费,非常适合这种实时但流量不均的推理场景。
  4. 开发流程: 我们严格遵循敏捷开发模式。第一个迭代周期(Sprint 1)的目标是搭建数据管道,实现基于SQL的特征工程,并训练出一个基线模型。第二个迭代周期(Sprint 2)则专注于模型的在线服务化和与API网关的集成。

最终确定的架构图如下:

graph TD
    subgraph "实时请求路径"
        Client[客户端] --> APIGW[阿里云API网关];
        APIGW -- "解析JWT, 提取client_id" --> AuthN[OAuth 2.0身份验证];
        APIGW -- "携带身份信息" --> InferenceFC[FC函数: 实时滥用检测];
        InferenceFC -- "允许/拒绝" --> APIGW;
        APIGW -- "若允许" --> Backend[后端服务];
    end

    subgraph "异步数据与训练路径"
        APIGW -- "访问日志(含client_id)" --> MNS[消息服务MNS];
        MNS --> LogProcessor[FC函数: 日志处理器];
        LogProcessor -- "结构化数据" --> RDS[MySQL RDS数据库];
        subgraph "离线训练环境 (ECS / PAI)"
            TrainingJob[定时训练任务] -- "读取数据" --> RDS;
            TrainingJob -- "特征工程 & TensorFlow训练" --> Model;
            Model[训练好的模型文件];
            TrainingJob -- "部署模型" --> NAS[文件存储NAS];
        end
        InferenceFC -- "加载模型" --> NAS;
    end

    Client -- "请求" --> AuthServer[OAuth 2.0授权服务器];
    AuthServer -- "返回JWT Token" --> Client;

Sprint 1:数据管道与MySQL的基石

一切智能分析都源于高质量的数据。我们的第一步是设计一个能够承载海量API访问日志并易于进行特征提取的MySQL表结构。

一个常见的错误是直接将原始日志文本塞进一个TEXT字段。这在查询时是灾难性的。我们设计的表结构必须是高度结构化的,并且为后续的查询性能做了深度优化。

-- api_access_logs.sql

CREATE TABLE `api_access_logs` (
  `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT,
  `request_id` VARCHAR(64) NOT NULL COMMENT '唯一请求ID',
  `client_id` VARCHAR(128) NOT NULL COMMENT 'OAuth 2.0客户端ID',
  `user_id` VARCHAR(128) DEFAULT NULL COMMENT '用户ID, 可能为空',
  `source_ip` VARCHAR(45) NOT NULL COMMENT '来源IP地址',
  `http_method` VARCHAR(10) NOT NULL COMMENT 'HTTP方法',
  `api_path` VARCHAR(255) NOT NULL COMMENT '请求路径, 不含Query参数',
  `http_status` SMALLINT UNSIGNED NOT NULL COMMENT 'HTTP状态码',
  `latency_ms` INT UNSIGNED NOT NULL COMMENT '请求处理延迟(毫秒)',
  `request_body_hash` VARCHAR(64) DEFAULT NULL COMMENT '请求体SHA256哈希, 用于识别重复请求',
  `user_agent` VARCHAR(512) DEFAULT NULL,
  `created_at` TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) COMMENT '请求时间, 精确到微秒',
  PRIMARY KEY (`id`),
  INDEX `idx_client_id_created_at` (`client_id`, `created_at`),
  INDEX `idx_user_id_created_at` (`user_id`, `created_at`),
  INDEX `idx_created_at` (`created_at`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='API访问日志表';

这里的关键设计考量:

  1. client_idcreated_at 的联合索引: 这是我们特征工程的核心。几乎所有的分析都是基于“某个客户端在某个时间窗口内的行为”,这个索引至关重要。
  2. TIMESTAMP(6): 微秒级精度对于分析短时间内的请求风暴非常有用。
  3. api_path: 我们只存储路径部分,查询参数会被剥离,以减少基数(cardinality)并方便聚合。
  4. request_body_hash: 这是一个高级特性,用于检测内容完全相同的重复提交型攻击。

数据写入的性能同样重要。API网关不能同步写入MySQL,这会引入不可接受的延迟。我们使用阿里云的消息服务MNS作为缓冲,API网关将日志投递到MNS队列,一个独立的FC函数消费这些消息并批量写入RDS。

这是日志处理函数的简化版核心逻辑:

# log_processor_fc.py
import os
import json
import logging
import mysql.connector
from mysql.connector import errorcode

# 从环境变量中安全地获取数据库凭证
DB_CONFIG = {
    'user': os.environ.get('DB_USER'),
    'password': os.environ.get('DB_PASSWORD'),
    'host': os.environ.get('DB_HOST'),
    'database': os.environ.get('DB_DATABASE'),
    'pool_name': 'mysql_pool',
    'pool_size': 5
}

# 数据库连接池在函数计算的全局作用域中初始化,以实现复用
try:
    cnx_pool = mysql.connector.pooling.MySQLConnectionPool(**DB_CONFIG)
except mysql.connector.Error as err:
    logging.error(f"Failed to create connection pool: {err}")
    cnx_pool = None

INSERT_STMT = (
    "INSERT INTO api_access_logs "
    "(request_id, client_id, user_id, source_ip, http_method, api_path, http_status, latency_ms, user_agent) "
    "VALUES (%(request_id)s, %(client_id)s, %(user_id)s, %(source_ip)s, %(http_method)s, %(api_path)s, %(http_status)s, %(latency_ms)s, %(user_agent)s)"
)

def handler(event, context):
    if not cnx_pool:
        raise RuntimeError("Database connection pool is not available.")
        
    try:
        # FC 的 MNS 触发器事件通常是 Base64 编码的
        event_data = json.loads(event)
        log_entries_str = event_data['data']
        # 假设日志是JSON数组格式的字符串
        log_entries = json.loads(log_entries_str)
    except (KeyError, json.JSONDecodeError) as e:
        logging.error(f"Failed to parse event data: {e}")
        return 'FAIL'

    conn = None
    cursor = None
    try:
        conn = cnx_pool.get_connection()
        cursor = conn.cursor()
        
        # 使用 executemany 进行批量插入,性能远高于单条插入
        cursor.executemany(INSERT_STMT, log_entries)
        conn.commit()
        
        logging.info(f"Successfully inserted {cursor.rowcount} log entries.")
        return 'OK'
    except mysql.connector.Error as err:
        logging.error(f"Database error: {err}")
        if conn and conn.is_connected():
            conn.rollback() # 发生错误时回滚
        return 'FAIL' # 返回FAIL,MNS会根据配置进行重试
    finally:
        if cursor:
            cursor.close()
        if conn and conn.is_connected():
            conn.close()

Sprint 2:TensorFlow模型与在线推理

数据就绪后,我们进入了核心阶段:特征工程和模型构建。我们的目标不是构建一个无比复杂的深度学习模型,而是一个能够快速验证、解释性强且推理性能高的基线模型。

特征工程

我们为每个client_id在过去5分钟的时间窗口内提取了以下特征:

  1. 请求总数 (total_requests): 最基本的频率指标。
  2. 独立端点数 (unique_endpoints): 访问的API路径种类数量。一个正常用户通常只访问少数几个端点,而爬虫会广泛扫描。
  3. 端点分布熵 (endpoint_entropy): 基于访问频率计算的信息熵。熵值越高,说明访问的端点越分散,行为越可疑。
  4. HTTP 4xx/5xx 错误率 (error_rate): 攻击者在探测时常会产生大量客户端或服务端错误。
  5. 请求间隔时间的标准差 (interval_stddev): 真实用户的请求间隔通常是无规律的,而机器行为的间隔可能非常稳定(标准差小)。
  6. UA多样性 (user_agent_diversity): 单一client_id下出现多种User-Agent的频率。

这些特征都可以通过SQL查询从api_access_logs表中计算得出。例如,计算请求总数和独立端点数:

SELECT
    COUNT(*) AS total_requests,
    COUNT(DISTINCT api_path) AS unique_endpoints
FROM
    api_access_logs
WHERE
    client_id = 'some_client_id' AND
    created_at >= NOW() - INTERVAL 5 MINUTE;

模型构建

我们选择了一个简单的多层感知机(MLP)模型。在真实项目中,模型的简单性往往意味着更好的可维护性和更低的推理延迟。

# model_training.py
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np

# 假设 feature_df 是从MySQL查询并计算好的特征DataFrame
# 列包括: ['total_requests', 'unique_endpoints', 'endpoint_entropy', ..., 'is_abuse']
# is_abuse 是我们人工标注或基于旧规则打标的标签 (0 或 1)

def build_and_train_model(feature_df):
    # 准备数据
    X = feature_df.drop('is_abuse', axis=1).values
    y = feature_df['is_abuse'].values

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

    # 特征标准化
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    
    # 保存scaler,推理时需要用
    # joblib.dump(scaler, 'scaler.joblib')

    # 定义模型结构
    model = models.Sequential([
        layers.Input(shape=(X_train_scaled.shape[1],)),
        layers.Dense(32, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        layers.Dropout(0.3),
        layers.Dense(16, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.001)),
        layers.Dropout(0.3),
        layers.Dense(1, activation='sigmoid') # 二分类问题,输出一个0-1之间的概率值
    ])

    # 编译模型
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                  loss='binary_crossentropy',
                  metrics=['accuracy', tf.keras.metrics.Precision(), tf.keras.metrics.Recall()])
    
    # 训练模型
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    model.fit(X_train_scaled, y_train,
              epochs=100,
              batch_size=64,
              validation_split=0.2,
              callbacks=[early_stopping],
              verbose=2)

    # 评估模型
    loss, accuracy, precision, recall = model.evaluate(X_test_scaled, y_test, verbose=0)
    print(f"Test Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    # 保存模型为 SavedModel 格式
    model.save('abuse_detection_model')
    
    return model, scaler

在线推理服务

训练好的模型和scaler对象被部署到NAS(或OSS),以便函数计算实例能够访问。推理函数的职责是:接收来自API网关的请求元数据,实时计算特征,加载模型进行预测,并返回决策。

# inference_fc.py
import os
import json
import logging
import numpy as np
import tensorflow as tf
import joblib  # for loading the scaler
import mysql.connector

# --- 全局初始化 ---
# 在函数实例启动时执行一次
MODEL_PATH = '/mnt/auto/abuse_detection_model' # FC挂载NAS的路径
SCALER_PATH = '/mnt/auto/scaler.joblib'

try:
    model = tf.saved_model.load(MODEL_PATH)
    scaler = joblib.load(SCALER_PATH)
    # 获取具体的推理函数签名
    infer = model.signatures["serving_default"]
except Exception as e:
    logging.error(f"Failed to load model or scaler: {e}")
    model, scaler, infer = None, None, None

# 数据库连接池
DB_CONFIG = { ... } # 同上
try:
    cnx_pool = mysql.connector.pooling.MySQLConnectionPool(**DB_CONFIG)
except Exception as e:
    logging.error(f"Failed to create inference DB pool: {e}")
    cnx_pool = None

# --- 核心处理逻辑 ---
def calculate_features(client_id, db_connection):
    # 在这里实现通过SQL查询实时计算特征的逻辑
    # 这是一个非常耗时的操作,真实生产环境需要更高效的方案,例如预计算或使用内存数据库
    # 此处为演示简化版
    cursor = db_connection.cursor(dictionary=True)
    query = """
    SELECT
        COUNT(*) AS total_requests,
        COUNT(DISTINCT api_path) AS unique_endpoints
    FROM api_access_logs
    WHERE client_id = %s AND created_at >= NOW() - INTERVAL 5 MINUTE;
    """
    cursor.execute(query, (client_id,))
    result = cursor.fetchone()
    cursor.close()
    
    # ... 计算其他特征 ...
    # 为了演示,我们只返回部分特征,并用假数据填充
    features = np.array([
        result.get('total_requests', 0), 
        result.get('unique_endpoints', 0),
        # ... 其他特征值
        0.5, 0.0, 1.2, 1 
    ]).reshape(1, -1)
    
    return features


def handler(event, context):
    if not all([model, scaler, infer, cnx_pool]):
        # 模型或数据库未就绪,快速失败,放行请求,并记录严重错误
        logging.critical("Inference service is not ready. Allowing request by default.")
        return json.dumps({'decision': 'allow', 'reason': 'service_not_ready'})

    try:
        # API网关透传的参数
        body = json.loads(event['body'])
        client_id = body.get('client_id')
        if not client_id:
            raise ValueError("client_id not found in request")
    except (json.JSONDecodeError, ValueError) as e:
        logging.warning(f"Invalid request format: {e}")
        return json.dumps({'decision': 'allow', 'reason': 'invalid_request'})

    conn = None
    try:
        conn = cnx_pool.get_connection()
        # 1. 实时特征计算
        feature_vector = calculate_features(client_id, conn)

        # 2. 特征标准化
        feature_vector_scaled = scaler.transform(feature_vector)
        
        # 3. 模型推理
        tensor_input = tf.constant(feature_vector_scaled, dtype=tf.float32)
        prediction = infer(tensor_input)
        
        # `prediction` 是一个字典,key是输出层的名字
        output_tensor_name = list(prediction.keys())[0]
        abuse_probability = prediction[output_tensor_name].numpy()[0][0]

        # 4. 决策
        # 这里的阈值需要根据业务容忍度和模型在验证集上的表现来精细调整
        if abuse_probability > 0.8:
            decision = 'deny'
            logging.info(f"Denied request from client {client_id} with score {abuse_probability:.4f}")
        else:
            decision = 'allow'
            
        return json.dumps({'decision': decision, 'score': float(abuse_probability)})

    except Exception as e:
        # 任何异常都应默认放行,避免误杀,并记录详细日志
        logging.error(f"Error during inference for client {client_id}: {e}", exc_info=True)
        return json.dumps({'decision': 'allow', 'reason': 'inference_error'})
    finally:
        if conn and conn.is_connected():
            conn.close()

这个方案上线后,我们成功地识别并拦截了之前无法检测到的慢速抓取攻击。API网关会同步调用这个推理函数,根据返回的decision来决定是否将请求转发给后端服务。整个推理过程的P99延迟控制在50ms以内,对用户体验的影响微乎其微。

局限与未来的迭代方向

当前这套系统并非完美。在真实的生产环境中,它的局限性也很明显:

  1. 实时特征计算瓶颈: 直接在RDS上对全量日志进行聚合查询来计算特征,随着数据量增长,性能会急剧下降。更优化的架构应该引入流式计算框架(如Flink)或一个预聚合层,将特征实时计算并存储在Redis或Tair这样的内存数据库中,供推理函数高速查询。
  2. 模型迭代与反馈闭环: 目前的模型是离线训练的,无法适应攻击模式的快速变化。下一步计划是建立一个半自动化的标签系统,让安全运营团队能够标注误报和漏报的案例,形成反馈数据流,实现模型的定期自动重训和部署,即MLOps。
  3. 冷启动问题: 函数计算的冷启动可能会导致偶尔的延迟尖峰。对于对延迟极其敏感的核心业务,可能需要配置预留实例,但这会增加成本。
  4. 样本不均衡: 滥用行为在总流量中是极少数,这会导致严重的样本不均衡问题。在训练过程中,需要采用过采样(SMOTE)、欠采样或代价敏感学习等方法来处理。

尽管存在这些待办事项,但通过两个敏捷周期的快速迭代,我们从一个简单的规则引擎演进到了一个具备学习能力的智能防护系统。这个过程本身验证了敏捷开发在应对复杂、未知安全挑战时的价值。它允许我们先构建一个最小可行产品(MVP),然后基于真实的反馈和数据,逐步向更完善的架构演进。


  目录