Middleware chains and extractors in Axum

1. Introduction to the middleware and extractor ecosystem in Axum

When building HTTP servers in Rust, Axum stands out for an architecture built on tower::Service and tower::Layer. Middlewares and extractors are two fundamental concepts in this ecosystem, but they play distinct roles.

Middlewares (or layers) are components that intercept requests and responses. They add cross-cutting functionality such as logging, authentication, compression and CORS. Together they form a processing pipeline, the middleware chain, that wraps the final handler.

Extractors are types that implement the FromRequestParts or FromRequest trait. They let a handler receive structured data from the HTTP request, such as route parameters, a JSON body, headers or shared state.

Axum builds on the tower crate. In Tower, each middleware is a Layer that wraps one Service and produces another Service. This design makes components easy to compose and reuse.
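The following model makes the Service/Layer relationship concrete. It is deliberately simplified, synchronous and uses only the standard library. The traits are stand-ins named after Tower's, not tower's actual definitions: the real tower::Service returns a Future and has a poll_ready method, and both are omitted here. `Hello`, `Shout` and `ShoutLayer` are hypothetical names.

```rust
// Simplified, synchronous stand-ins for tower::Service and tower::Layer.
// The real Service trait is async and also has poll_ready for backpressure.
trait Service<Req> {
    type Response;
    fn call(&self, req: Req) -> Self::Response;
}

trait Layer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

// The innermost service, playing the role of a handler.
struct Hello;

impl Service<String> for Hello {
    type Response = String;
    fn call(&self, req: String) -> String {
        format!("Hello, {req}!")
    }
}

// A middleware service: it delegates to the service it wraps and then
// transforms the response on the way out.
struct Shout<S> {
    inner: S,
}

impl<S> Service<String> for Shout<S>
where
    S: Service<String, Response = String>,
{
    type Response = String;
    fn call(&self, req: String) -> String {
        self.inner.call(req).to_uppercase()
    }
}

// The layer is a factory that wraps any inner service in Shout.
struct ShoutLayer;

impl<S> Layer<S> for ShoutLayer {
    type Service = Shout<S>;
    fn layer(&self, inner: S) -> Shout<S> {
        Shout { inner }
    }
}
```

Calling `ShoutLayer.layer(Hello)` yields a `Shout<Hello>` service. Stacking layers one around another like this is, in essence, what ServiceBuilder and Router::layer do for you.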

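2. Creating and applying custom middlewares

The following example creates a simple logging middleware that records the method, path and duration of each request. The examples target axum 0.7. In that version, middleware::from_fn accepts an async function that receives the request plus a Next handle to the rest of the chain: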

use axum::{
    body::Body,
    http::Request,
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};
use std::time::Instant;

// Logs method, URI, response status and elapsed time for every request.
async fn logging_middleware(req: Request<Body>, next: Next) -> Response {
    let start = Instant::now();
    // Copy what we need before `req` is moved into the rest of the chain.
    let method = req.method().clone();
    let uri = req.uri().clone();

    // Hand the request to the inner layers and, finally, the handler.
    let response = next.run(req).await;

    let elapsed = start.elapsed();
    println!("{} {} -> {} ({}ms)", method, uri, response.status(), elapsed.as_millis());

    response
}

async fn hello_handler() -> &'static str {
    "Hello, World!"
}

#[tokio::main]
async fn main() {
    let app = Router::new()
        .route("/", get(hello_handler))
        .layer(middleware::from_fn(logging_middleware));

    // axum 0.7 removed axum::Server: bind a Tokio listener and call axum::serve.
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

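To chain multiple middlewares, use tower::ServiceBuilder. Layers added to a ServiceBuilder run top to bottom on the request. The first .layer() is the outermost, so it sees the request first and the response last: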

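use tower::ServiceBuilder;
use tower_http::{compression::CompressionLayer, cors::CorsLayer, trace::TraceLayer};

// Continues the previous example (hello_handler, logging_middleware).
fn app() -> Router {
    Router::new()
        .route("/", axum::routing::get(hello_handler))
        .layer(
            ServiceBuilder::new()
                .layer(TraceLayer::new_for_http()) // outermost: first to see the request
                .layer(CompressionLayer::new())
                .layer(CorsLayer::permissive())
                .layer(middleware::from_fn(logging_middleware)), // innermost: closest to the handler
        )
}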
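A small std-only model of the chain illustrates this ordering rule. The model assumes requests and responses are plain strings and treats each middleware as a named stage that calls the rest of the chain, much like `next.run(req)` inside a from_fn middleware. `run_chain` and the stage names are hypothetical. Layers are listed in ServiceBuilder order, so the first one is the outermost.

```rust
/// Runs `layers` (listed in ServiceBuilder order) around `handler`.
/// The first entry is the outermost layer; `log` records the order of events.
fn run_chain(
    layers: &[&str],
    req: &str,
    log: &mut Vec<String>,
    handler: &dyn Fn(&str) -> String,
) -> String {
    match layers.split_first() {
        // No layers left: the request reaches the handler.
        None => {
            log.push("handler".to_string());
            handler(req)
        }
        Some((name, rest)) => {
            log.push(format!("{name}: request"));
            // Equivalent of `next.run(req)`: delegate to the inner layers.
            let resp = run_chain(rest, req, log, handler);
            log.push(format!("{name}: response"));
            resp
        }
    }
}
```

With the four layers from the ServiceBuilder above, the log shows trace entering first and leaving last. The logging stage sits closest to the handler.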

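3. Ready-made middlewares from Axum and the Tower ecosystem

The tower-http crate offers many middlewares that are ready to use. Most are enabled through Cargo features, for example `trace`, `cors`, `timeout`, `set-header` and `compression-gzip`: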

use axum::{
    http::{header, HeaderValue},
    routing::get,
    Router,
};
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::{
    compression::CompressionLayer,
    cors::CorsLayer,
    set_header::SetResponseHeaderLayer,
    timeout::TimeoutLayer,
    trace::TraceLayer,
};

async fn get_data() -> &'static str {
    "data" // placeholder handler for this example
}

fn app() -> Router {
    Router::new()
        .route("/api/data", get(get_data))
        .layer(
            ServiceBuilder::new()
                .layer(TraceLayer::new_for_http()) // structured logging via `tracing`
                .layer(CompressionLayer::new()) // gzip/brotli, depending on enabled features
                .layer(CorsLayer::permissive()) // permissive CORS, best kept for development
                .layer(TimeoutLayer::new(Duration::from_secs(30))) // global request timeout
                .layer(SetResponseHeaderLayer::overriding(
                    header::SERVER,
                    HeaderValue::from_static("Axum/0.7"),
                )),
        )
}

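TraceLayer emits spans and events through the tracing crate. Nothing is recorded until a subscriber is installed, such as one from the tracing-subscriber crate. The same spans can be exported to distributed tracing systems such as OpenTelemetry through a bridge like tracing-opentelemetry, which makes TraceLayer particularly useful for debugging in production.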
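The TimeoutLayer used above answers with 408 Request Timeout when the inner service does not finish in time. The following sketch models that behavior with threads and a channel from the standard library. `with_timeout` and the `(status, body)` response type are assumptions made for illustration.

The sketch differs from the real async layer in one respect. The real layer drops the pending future. Here, the slow handler keeps running in the background and its result is discarded.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

/// Simplified response: (status code, body).
type Response = (u16, String);

/// Runs `handler` and waits at most `limit` for its answer.
/// On expiry it returns 408, like a timeout middleware would.
fn with_timeout<F>(limit: Duration, handler: F) -> Response
where
    F: FnOnce() -> Response + Send + 'static,
{
    let (tx, rx) = mpsc::channel();
    thread::spawn(move || {
        // The receiver is gone if we already timed out; ignore that error.
        let _ = tx.send(handler());
    });
    match rx.recv_timeout(limit) {
        Ok(response) => response,
        // Timed out (or the handler thread panicked): answer 408.
        Err(_) => (408, String::new()),
    }
}
```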

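4. Fundamental extractors: pulling data out of the request

Extractors are the idiomatic way to access request data in handlers. Each handler argument is an extractor, and their order follows one rule. Extractors that only read the request parts (Path, Query, State, Extension, headers) can appear in any order. An extractor that consumes the body (such as Json or Form) must be the last argument, because the body can be read only once: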

use axum::{
    extract::{Extension, Form, Path, Query},
    routing::{get, post},
    Json, Router,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

// Optional query-string filters, e.g. /users/42?include_deleted=true
#[derive(Deserialize)]
struct UserFilter {
    include_deleted: Option<bool>,
}

#[derive(Serialize)]
struct User {
    id: u32,
    name: String,
    email: String,
}

#[derive(Deserialize)]
struct CreateUser {
    name: String,
    email: String,
}

#[derive(Clone)]
struct AppState {
    db_pool: String,
}

// Path<u32> parses the `:id` route segment; Query parses the query string.
async fn get_user(Path(id): Path<u32>, Query(filter): Query<UserFilter>) -> Json<User> {
    println!("include_deleted = {:?}", filter.include_deleted);
    Json(User {
        id,
        name: "Alice".to_string(),
        email: "alice@example.com".to_string(),
    })
}

// Json<T> deserializes a JSON body (requires `Content-Type: application/json`).
async fn create_user(Json(payload): Json<CreateUser>) -> Json<User> {
    Json(User { id: 1, name: payload.name, email: payload.email })
}

// Extension<T> rejects with 500 unless a layer inserted an AppState.
async fn search_users(
    Query(query): Query<HashMap<String, String>>,
    Extension(state): Extension<AppState>,
) -> Json<Vec<User>> {
    println!("Database: {}", state.db_pool);
    println!("Search query: {:?}", query);
    Json(vec![])
}

// Form<T> deserializes an `application/x-www-form-urlencoded` body.
async fn submit_form(Form(form): Form<CreateUser>) -> Json<User> {
    Json(User { id: 2, name: form.name, email: form.email })
}

fn app() -> Router {
    Router::new()
        .route("/users", post(create_user))
        .route("/users/:id", get(get_user))
        .route("/search", get(search_users))
        .route("/form", post(submit_form))
        .layer(Extension(AppState { db_pool: "postgres://localhost/mydb".to_string() }))
}
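Path extraction relies on two steps. First, the router matches the request path against a pattern such as /users/:id. Then the captured segments are converted to the requested type. The following std-only sketch imitates those two steps. `match_route` and `path_u32` are hypothetical helpers. Axum's actual router, based on the matchit crate, supports much more, including wildcards.

```rust
use std::collections::HashMap;

/// Matches `path` against `pattern`, collecting `:name` captures.
/// Only literal segments and `:name` segments are supported.
fn match_route(pattern: &str, path: &str) -> Option<HashMap<String, String>> {
    let pat: Vec<&str> = pattern.trim_matches('/').split('/').collect();
    let segs: Vec<&str> = path.trim_matches('/').split('/').collect();
    if pat.len() != segs.len() {
        return None;
    }
    let mut params = HashMap::new();
    for (p, s) in pat.iter().zip(&segs) {
        if let Some(name) = p.strip_prefix(':') {
            params.insert(name.to_string(), s.to_string());
        } else if p != s {
            return None; // literal segment mismatch
        }
    }
    Some(params)
}

/// Typed step, similar in spirit to Path<u32>: a value that does not parse is rejected.
fn path_u32(params: &HashMap<String, String>, name: &str) -> Result<u32, String> {
    let raw = params
        .get(name)
        .ok_or_else(|| format!("missing parameter `{name}`"))?;
    raw.parse().map_err(|_| format!("invalid `{name}`: {raw}"))
}
```

Note that a path like /users/abc still matches the route. Only the typed conversion fails, which is why a bad route parameter surfaces as an extractor rejection rather than a 404.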

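5. Advanced and custom extractors

Custom extractors implement one of two traits. Use FromRequestParts for data that does not consume the body, such as headers. Use FromRequest when the body must be read. In axum 0.7 both traits are declared with #[async_trait], so implementations need the attribute as well: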

use axum::{
    async_trait,
    extract::FromRequestParts,
    http::{request::Parts, StatusCode},
};

// Extractor for the bearer token (e.g. a JWT) in the Authorization header
struct AuthToken(String);

#[async_trait]
impl<S> FromRequestParts<S> for AuthToken
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let auth_header = parts
            .headers
            .get("Authorization")
            .and_then(|value| value.to_str().ok())
            .ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header"))?;

        if let Some(token) = auth_header.strip_prefix("Bearer ") {
            Ok(AuthToken(token.to_string()))
        } else {
            Err((StatusCode::BAD_REQUEST, "Invalid Authorization header format"))
        }
    }
}

// Extractor that validates the token and returns the user's data
#[derive(Debug)]
struct AuthenticatedUser {
    id: u32,
    name: String,
}

#[async_trait]
impl<S> FromRequestParts<S> for AuthenticatedUser
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, &'static str);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let token = AuthToken::from_request_parts(parts, state).await?;

        // Simulated token validation; real code would verify the JWT here
        if token.0 == "valid-token" {
            Ok(AuthenticatedUser {
                id: 42,
                name: "Alice".to_string(),
            })
        } else {
            Err((StatusCode::UNAUTHORIZED, "Invalid token"))
        }
    }
}

async fn protected_handler(
    user: AuthenticatedUser,
) -> String {
    format!("Welcome, {}! (ID: {})", user.name, user.id)
}
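The AuthToken extractor above uses strip_prefix("Bearer "). That call rejects a header such as `bearer abc`, even though HTTP authentication scheme names are case-insensitive (RFC 7235, now part of RFC 9110). A more lenient parsing step might look like the following. It is written as plain std Rust so it can be unit-tested apart from axum. The function name and the `AuthError` variants are hypothetical.

```rust
/// Why a header was rejected; mirrors the two rejections used by AuthToken.
#[derive(Debug, PartialEq)]
enum AuthError {
    Missing,   // would map to 401 UNAUTHORIZED
    BadFormat, // would map to 400 BAD_REQUEST
}

/// Extracts the token from an `Authorization: Bearer <token>` header value.
/// The scheme is compared case-insensitively and extra whitespace is ignored.
fn parse_bearer(header: Option<&str>) -> Result<&str, AuthError> {
    let value = header.ok_or(AuthError::Missing)?;
    let (scheme, token) = value.trim().split_once(' ').ok_or(AuthError::BadFormat)?;
    let token = token.trim();
    if scheme.eq_ignore_ascii_case("bearer") && !token.is_empty() {
        Ok(token)
    } else {
        Err(AuthError::BadFormat)
    }
}
```

Inside `from_request_parts`, the result could be mapped to the same `(StatusCode, &'static str)` rejections the original extractor returns.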

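6. Execution order and shared state

Layer order is crucial: outer layers run first on the request and last on the response. The two ways of adding layers order them differently:

- With Router::layer, each call wraps everything added before it, so the last .layer() call is the outermost.
- Inside a ServiceBuilder the order is reversed, and the first .layer() is the outermost.

Shared state is attached with Router::with_state and read by handlers through the State extractor. Middlewares built with middleware::from_fn do not see that state. A middleware that needs it should be created with middleware::from_fn_with_state(state.clone(), my_middleware):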

use axum::{extract::State, routing::get, Router};
use std::time::Duration;
use tower::ServiceBuilder;
use tower_http::{timeout::TimeoutLayer, trace::TraceLayer};

#[derive(Clone)]
struct AppConfig {
    database_url: String,
    api_key: String,
}

async fn health_check(State(config): State<AppConfig>) -> String {
    format!("Connected to: {}", config.database_url)
}

#[tokio::main]
async fn main() {
    let config = AppConfig {
        database_url: "postgres://localhost/mydb".to_string(),
        api_key: "secret-key".to_string(),
    };

    let app = Router::new()
        .route("/health", get(health_check))
        .with_state(config) // global state, available through the State extractor
        .layer(
            ServiceBuilder::new()
                .layer(TraceLayer::new_for_http()) // outer: sees the request first
                .layer(TimeoutLayer::new(Duration::from_secs(10))), // inner
        );

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    axum::serve(listener, app).await.unwrap();
}
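A small std-only model can check the difference between stacking .layer() calls on a Router and listing layers in a ServiceBuilder. Each hypothetical layer records when it runs. `router_style` folds the names in call order, so each call wraps what came before. `builder_style` makes the first name the outermost. These functions model the ordering rule only; they are not axum or tower APIs.

```rust
use std::cell::RefCell;
use std::rc::Rc;

type Service = Box<dyn Fn(&str) -> String>;
type Log = Rc<RefCell<Vec<String>>>;

/// Wraps `inner` in a named layer that logs before and after delegating.
fn layer(name: &'static str, log: &Log, inner: Service) -> Service {
    let log = Rc::clone(log);
    Box::new(move |req: &str| {
        log.borrow_mut().push(format!("{name} in"));
        let resp = inner(req);
        log.borrow_mut().push(format!("{name} out"));
        resp
    })
}

/// Like `Router::new().layer(a).layer(b)`: the last name ends up outermost.
fn router_style(names: &[&'static str], log: &Log, handler: Service) -> Service {
    names.iter().copied().fold(handler, |svc, name| layer(name, log, svc))
}

/// Like `ServiceBuilder::new().layer(a).layer(b)`: the first name is outermost.
fn builder_style(names: &[&'static str], log: &Log, handler: Service) -> Service {
    names.iter().rev().copied().fold(handler, |svc, name| layer(name, log, svc))
}
```

With names `a` and `b`, the router-style stack logs `b in, a in, a out, b out`. The builder-style stack logs `a in, b in, b out, a out`.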

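Best practices:
- Use State for data shared among handlers, such as configuration or connection pools.
- Prefer Extension for data that a middleware injects per request (for example, the authenticated user).
- Avoid hidden ordering dependencies between middlewares: a middleware or handler can only read extensions inserted by layers that run before it.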
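Extension works because each request carries a map of values keyed by their type. A middleware inserts a value, and a handler later asks for it by type. The following std-only sketch models such a type map with TypeId and Any. It illustrates the idea only and is not the implementation of http::Extensions. `Extensions`, `AuthUser` and `extract_auth_user` are hypothetical names here. The 500 status mirrors axum's rejection when a requested extension is missing.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// A per-request type map: at most one value per type, looked up by TypeId.
#[derive(Default)]
struct Extensions {
    map: HashMap<TypeId, Box<dyn Any>>,
}

impl Extensions {
    /// Inserts a value, replacing any previous value of the same type.
    fn insert<T: 'static>(&mut self, value: T) {
        self.map.insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Returns the value of type T, if some middleware inserted one.
    fn get<T: 'static>(&self) -> Option<&T> {
        self.map.get(&TypeId::of::<T>()).and_then(|b| b.downcast_ref::<T>())
    }
}

#[derive(Clone, Debug, PartialEq)]
struct AuthUser {
    id: u32,
    name: String,
}

/// Handler-side lookup, like Extension<AuthUser>. A missing value means the
/// middleware stack is misconfigured, hence an internal error (500).
fn extract_auth_user(ext: &Extensions) -> Result<AuthUser, u16> {
    ext.get::<AuthUser>().cloned().ok_or(500)
}
```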

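7. Error handling and fallback responses in middlewares

Middlewares can handle errors centrally in two ways. They can short-circuit a request with an error response before it reaches the handler, as auth_middleware does below. They can also inspect and rewrite responses produced further down the chain, as error_handler_middleware does: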

use axum::{
    body::Body,
    http::{Request, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
    Json,
};

// Rejects requests without an Authorization header with a JSON 401 body.
async fn auth_middleware(
    req: Request<Body>,
    next: Next,
) -> Result<Response, (StatusCode, Json<serde_json::Value>)> {
    let token = req
        .headers()
        .get("Authorization")
        .and_then(|v| v.to_str().ok());

    match token {
        // Only presence is checked here; real code must validate the token.
        Some(_) => Ok(next.run(req).await),
        None => Err((
            StatusCode::UNAUTHORIZED,
            Json(serde_json::json!({
                "error": "Authentication required",
                "code": "UNAUTHORIZED"
            })),
        )),
    }
}

// Global error normalization: replaces any 5xx response with a generic 500.
async fn error_handler_middleware(req: Request<Body>, next: Next) -> Response {
    let response = next.run(req).await;

    if response.status().is_server_error() {
        // Log the original status, then return a standardized response
        // so internal details do not leak to clients.
        eprintln!("Server error: {}", response.status());
        (StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong").into_response()
    } else {
        response
    }
}
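Both middlewares can be modeled without axum to test their control flow. The following std-only sketch represents a response as `(status, body)` and the rest of the chain as a closure. `normalize_errors` and `require_auth` are hypothetical stand-ins for error_handler_middleware and auth_middleware.

```rust
/// Simplified response: (status, body).
type Response = (u16, String);

/// Model of error_handler_middleware: run the rest of the chain, then
/// replace any 5xx response with a generic 500 body.
fn normalize_errors(next: impl Fn() -> Response) -> Response {
    let (status, body) = next();
    if (500..600).contains(&status) {
        eprintln!("Server error: {status} ({body})"); // details stay in the logs
        (500, "Something went wrong".to_string())
    } else {
        (status, body)
    }
}

/// Model of auth_middleware: short-circuit with 401 when the header is
/// missing; otherwise continue down the chain.
fn require_auth(auth_header: Option<&str>, next: impl Fn() -> Response) -> Response {
    match auth_header {
        Some(_) => next(),
        None => (
            401,
            r#"{"error":"Authentication required","code":"UNAUTHORIZED"}"#.to_string(),
        ),
    }
}
```

When normalize_errors wraps require_auth, a missing header still yields 401. This is because the short-circuit happens before the failing handler is ever called, and 401 is not a server error.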

8. Complete example: an application with a middleware chain and extractors

use axum::{
    extract::{Path, Query, Extension, State},
    response::Json,
    routing::get,
    Router,
    middleware,
    http::StatusCode,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tower::ServiceBuilder;
use tower_http::{cors::CorsLayer, compression::CompressionLayer, trace::TraceLayer};

#[derive(Clone)]
struct AppState {
    db: Arc<dyn Database>,
}

trait Database: Send + Sync {
    fn get_user(&self, id: u32) -> Option<User>;
}

#[derive(Serialize, Deserialize, Clone)]
struct User {
    id: u32,
    name: String,
}

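// Optional query parameters, e.g. /users/1?fields=name
// (the id already comes from the path, so it is not repeated here)
#[derive(Deserialize)]
struct UserQuery {
    fields: Option<String>,
}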

#[derive(Clone)]
struct AuthUser {
    id: u32,
    name: String,
}

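// Authentication middleware (stub): always "authenticates" a fixed user and
// stores it in the request extensions for the Extension<AuthUser> extractor.
async fn auth_middleware(
    mut req: axum::http::Request<axum::body::Body>,
    next: middleware::Next,
) -> Result<axum::response::Response, (StatusCode, Json<serde_json::Value>)> {
    // Real code would validate a token here and return Err(...) on failure.
    let user = AuthUser { id: 1, name: "Admin".to_string() };
    req.extensions_mut().insert(user);
    Ok(next.run(req).await)
}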

async fn get_user_handler(
    Path(id): Path<u32>,
    Query(query): Query<UserQuery>,
    Extension(auth_user): Extension<AuthUser>,
    State(state): State<AppState>,
) -> Result<Json<User>, (StatusCode, String)> {
    println!("User {} accessed by {} (fields: {:?})", id, auth_user.name, query.fields);

    match state.db.get_user(id) {
        Some(user) => Ok(Json(user)),
        None => Err((StatusCode::NOT_FOUND, "User not found".to_string())),
    }
}

#[tokio::main]
async fn main() {
    let state = AppState {
        db: Arc::new(MockDatabase),
    };

    let app = Router::new()
        .route("/users/:id", get(get_user_handler))
        .with_state(state)
        .layer(
            ServiceBuilder::new()
                .layer(TraceLayer::new_for_http())
                .layer(CompressionLayer::new())
                .layer(CorsLayer::permissive())
                .layer(middleware::from_fn(auth_middleware))
        );

    // axum 0.7: bind a Tokio listener and hand it to axum::serve.
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
    println!("Server running on http://localhost:3000");
    axum::serve(listener, app).await.unwrap();
}

struct MockDatabase;

impl Database for MockDatabase {
    fn get_user(&self, id: u32) -> Option<User> {
        if id == 1 {
            Some(User { id: 1, name: "Alice".to_string() })
        } else {
            None
        }
    }
}

References