使用 FastAPI 将模型封装为在线的 API 服务
从离线模型到在线API服务的解决方案之一
目标
将本地/离线模型封装为可调用的 HTTP API 服务,统一输入输出契约,支持训练、预测、状态与模型管理,便于前端/平台集成。
设计要点
- 数据契约优先:明确输入/输出字段、单位、长度、边界与错误提示
- 模块化分层:
router(API 层)/service(业务编排)/core(算法与 I/O) - 可观测性:健康检查、训练进度、验证指标、模型元数据
- 可持续:模型持久化(权重+scaler+metadata)、模型列表/删除、损失曲线
快速模板
1. 项目结构
app/
main.py # 启动入口(聚合 routers)
api/
__init__.py
routers/
predict.py # 预测路由
train.py # 训练路由
health.py # 健康检查
service/
predictor.py # 预测编排
trainer.py # 训练编排
core/
model.py # 模型定义/加载/保存
io.py # 标准化/反标准化/序列构造
2. 依赖声明(示例)
pip install fastapi uvicorn pydantic numpy torch scikit-learn joblib
3. 关键代码片段
main.py
from fastapi import FastAPI
from api.routers.predict import router as predict_router
from api.routers.train import router as train_router
from api.routers.health import router as health_router
app = FastAPI(title="模型服务", version="1.0.0")
app.include_router(train_router, prefix="/train")
app.include_router(predict_router, prefix="/predict")
app.include_router(health_router, prefix="")
预测路由(最小示例)
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from typing import List
from api.service.predictor import predict_batch
router = APIRouter(tags=["预测"])
class PredictIn(BaseModel):
future_features: List[List[float]]
recent_historical_features: List[List[float]]
recent_historical_target: List[float]
model_id: str | None = None
@router.post("", summary="批量预测")
async def do_predict(payload: PredictIn):
try:
return await predict_batch(payload)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
训练路由(异步)
from fastapi import APIRouter, BackgroundTasks
from pydantic import BaseModel
from api.service.trainer import start_training
router = APIRouter(tags=["训练"])
class TrainIn(BaseModel):
historical_target: list[float]
historical_features: list[list[float]]
sequence_length: int = 30
@router.post("", summary="启动训练(异步)")
async def do_train(payload: TrainIn, background: BackgroundTasks):
task_id = await start_training(payload, background)
return {"task_id": task_id, "status": "accepted"}
健康检查
from fastapi import APIRouter
router = APIRouter(tags=["健康检查"])
@router.get("/health")
async def health():
return {"status": "healthy"}
关键
- 输入校验:长度/维度一致性、数值范围、缺失值处理
- 标准化一致:训练与预测使用同一 scaler(持久化与加载)
- 元数据记录:config、训练损失、验证指标、创建时间、描述
- 预测策略:滑窗迭代、滞后特征构造与验证
- 错误响应:HTTP 4xx(输入错误)、5xx(系统错误),信息清晰
常见问题(FAQ)
- Q: 输入维度对不上怎么办?
- A: 在训练时把
input_size写入 metadata,预测时校验;不符直接 400。
- A: 在训练时把
- Q: 预测慢如何优化?
- A: 复用 scaler 与模型;使用批量推理;CPU 多进程或启用 GPU。
- Q: 如何支持多模型?
- A: 用
models/uuid/目录结构存放模型与元数据,提供/models列表与选择。
- A: 用