Spring Boot — Refresh Tokens
Spring Boot

Refresh Tokens

Refresh tokens enable long-lived authentication without storing long-lived access tokens. The access token is short-lived (5–15 minutes) and stateless; the refresh token is long-lived (days to weeks) and stored server-side so it can be revoked. This entry covers refresh token storage, issuance, rotation, family-based theft detection, revocation, and Redis-backed storage for high-throughput applications.

Refresh Token Entity and Repository

Store refresh tokens server-side so they can be revoked individually or in bulk. Hash or encrypt the token value before storage so a database breach does not expose active tokens. Index the token hash column and the user-email column for fast lookup and bulk revocation.
Java
// ── Refresh token entity ─────────────────────────────────────────────
/**
 * One server-side refresh token record.
 *
 * <p>Only the SHA-256 hash of the raw token is persisted ({@code tokenHash});
 * the raw value is handed to the client and never stored. Tokens issued
 * from the same login share a {@code familyId}, so a whole rotation chain
 * can be revoked at once.
 */
@Entity
@Table(
    name = "refresh_tokens",
    indexes = {
        // NOTE(review): token_hash is also declared unique=true below, and
        // most databases back a unique constraint with its own index, so
        // this index may be redundant — confirm against the generated DDL.
        @Index(name = "idx_rt_token_hash",
               columnList = "token_hash"),
        // Supports per-user lookup and bulk "logout everywhere" revocation.
        @Index(name = "idx_rt_user_email",
               columnList = "user_email"),
        // Supports revoking a whole family on reuse detection.
        @Index(name = "idx_rt_family",
               columnList = "family_id"),
        // Supports the scheduled cleanup that deletes by expiry.
        @Index(name = "idx_rt_expires_at",
               columnList = "expires_at")
    }
)
@Getter @Setter @NoArgsConstructor
public class RefreshToken {

    @Id @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    // SHA-256 of the raw token — never store plain text.
    // 64 chars = 32-byte digest rendered as lowercase hex.
    @Column(name = "token_hash", nullable = false,
            unique = true, length = 64)
    private String tokenHash;

    // Owner of the token; used for bulk revocation per user.
    @Column(name = "user_email", nullable = false, length = 255)
    private String userEmail;

    // Family ID groups all tokens from the same login session.
    // Used for theft detection: if a revoked token in a family
    // is presented, the entire family is invalidated.
    // 36 chars = textual UUID.
    @Column(name = "family_id", nullable = false, length = 36)
    private String familyId;

    // Absolute expiry instant; checked on rotation and by cleanup.
    @Column(name = "expires_at", nullable = false)
    private Instant expiresAt;

    // Set on rotation, logout, expiry, or family/user-wide revocation.
    // A revoked token presented again is treated as reuse.
    @Column(nullable = false)
    private boolean revoked = false;

    // Not populated automatically (no @PrePersist / @CreationTimestamp);
    // the issuing code must set it before saving.
    @Column(name = "created_at", nullable = false)
    private Instant createdAt;

    // When the token was exchanged for a new pair; null until rotated.
    @Column(name = "used_at")
    private Instant usedAt;

    // Client User-Agent captured at issuance. NOTE(review): the header can
    // exceed 255 chars, so callers should truncate before saving.
    @Column(name = "device_info", length = 255)
    private String deviceInfo;

    // Client IP at issuance; 45 chars fits the longest textual IPv6 form.
    @Column(name = "ip_address", length = 45)
    private String ipAddress;
}

// ── Repository ────────────────────────────────────────────────────────
/**
 * Persistence operations for {@link RefreshToken}.
 *
 * <p>The {@code @Modifying} bulk queries run directly against the database
 * and bypass the persistence context: entities already loaded in the current
 * session are not updated in memory.
 */
public interface RefreshTokenRepository
        extends JpaRepository<RefreshToken, Long> {

    /** Looks up a token by the SHA-256 hex hash of its raw value. */
    Optional<RefreshToken> findByTokenHash(String tokenHash);

    /**
     * Returns the user's non-revoked tokens. Expired-but-not-yet-cleaned
     * tokens are included; filter on {@code expiresAt} for live sessions.
     */
    List<RefreshToken> findByUserEmailAndRevokedFalse(
        String email);

    /** Returns every token (revoked or not) issued within one login family. */
    List<RefreshToken> findByFamilyId(String familyId);

    /** Marks every token in the family as revoked (used on reuse detection). */
    @Modifying
    @Query("UPDATE RefreshToken t SET t.revoked = true " +
           "WHERE t.familyId = :familyId")
    void revokeFamily(@Param("familyId") String familyId);

    /** Marks every token of the user as revoked (logout from all devices). */
    @Modifying
    @Query("UPDATE RefreshToken t SET t.revoked = true " +
           "WHERE t.userEmail = :email")
    void revokeAllForUser(@Param("email") String email);

    /**
     * Deletes tokens whose expiry has passed, revoked or not.
     *
     * <p>Revoked tokens are deliberately kept until they expire. A rotated
     * (revoked) token must still be found when it is replayed, so that reuse
     * detection can revoke its whole family. Deleting revoked rows early
     * turns such a replay into a plain "not found" and silently disables
     * theft detection. Every revoked token is still removed once its
     * {@code expiresAt} passes. The method name is kept for backward
     * compatibility.
     *
     * @param now cutoff instant; rows with {@code expiresAt < now} are deleted
     * @return number of rows deleted
     */
    @Modifying
    @Query("DELETE FROM RefreshToken t " +
           "WHERE t.expiresAt < :now")
    int deleteExpiredAndRevoked(@Param("now") Instant now);
}

Token Issuance and Rotation

Issue a new refresh token on every use (token rotation), and revoke the old one in the same transaction that issues the new one. Rotation means each token can be used only once. If an attacker steals a token and uses it first, the legitimate client's next refresh presents an already-used token. That reuse is detected, and the entire token family is revoked.
Java
@Service
@RequiredArgsConstructor
@Slf4j
public class RefreshTokenService {

    /** One shared CSPRNG; SecureRandom is safe for concurrent use. */
    private static final SecureRandom RANDOM = new SecureRandom();

    /** Widths of refresh_tokens.device_info and ip_address. */
    private static final int DEVICE_INFO_MAX_LENGTH = 255;
    private static final int IP_ADDRESS_MAX_LENGTH  = 45;

    private final RefreshTokenRepository tokenRepo;
    private final JwtService             jwtService;

    @Value("${app.jwt.refresh-token-expiry:604800}")
    private long refreshTokenExpirySeconds;

    // ── Create a new refresh token (first login in family) ────────────
    /**
     * Starts a new token family for a fresh login and issues its first pair.
     *
     * @param user    the authenticated user
     * @param request the login request; User-Agent and client IP are recorded
     * @return the new access token and raw refresh token
     */
    @Transactional
    public TokenPair createTokenPair(AppUser user,
                                      HttpServletRequest request) {
        String familyId = UUID.randomUUID().toString();
        return issueTokenPair(user, familyId, request);
    }

    // ── Rotate: validate old token, issue new one in same family ──────
    /**
     * Exchanges a refresh token for a new pair in the same family and
     * revokes the presented token.
     *
     * <p>{@code noRollbackFor}: both failure paths write state before
     * throwing. On reuse they revoke the family; on expiry they mark the
     * token revoked. These exceptions are unchecked, and Spring rolls back
     * on any RuntimeException by default. Without this setting the family
     * revocation, which is the whole point of theft detection, would be
     * silently undone.
     *
     * @param rawRefreshToken the raw token sent by the client
     * @param request         the refresh request; metadata is recorded
     * @return the new access token and raw refresh token
     * @throws InvalidTokenException if the token is missing, unknown or expired
     * @throws TokenTheftException   if an already-revoked token is presented;
     *                               the whole family is revoked first
     */
    @Transactional(noRollbackFor = {
        TokenTheftException.class, InvalidTokenException.class})
    public TokenPair rotate(String rawRefreshToken,
                             HttpServletRequest request) {
        if (rawRefreshToken == null || rawRefreshToken.isBlank()) {
            throw new InvalidTokenException("Refresh token not found");
        }

        String hash = hashToken(rawRefreshToken);
        RefreshToken stored = tokenRepo.findByTokenHash(hash)
            .orElseThrow(() -> new InvalidTokenException(
                "Refresh token not found"));

        // NOTE(review): two concurrent requests carrying the same token can
        // both pass the revoked check before either commits. Consider a
        // pessimistic lock on the lookup or an @Version column on the entity.

        // ── Theft detection: already-used token presented ─────────────
        if (stored.isRevoked()) {
            // Someone is replaying a previously rotated token.
            // Invalidate the entire family — all sessions for this login.
            log.warn("Refresh token reuse detected — " +
                "revoking family {} for user {}",
                stored.getFamilyId(), stored.getUserEmail());
            tokenRepo.revokeFamily(stored.getFamilyId());
            throw new TokenTheftException(
                "Suspicious activity detected. " +
                "Please log in again.");
        }

        Instant now = Instant.now();

        // ── Check expiry ───────────────────────────────────────────────
        if (stored.getExpiresAt().isBefore(now)) {
            stored.setRevoked(true);
            tokenRepo.save(stored);
            throw new InvalidTokenException("Refresh token expired");
        }

        // ── Revoke the old token ───────────────────────────────────────
        stored.setRevoked(true);
        stored.setUsedAt(now);
        tokenRepo.save(stored);

        // ── Issue new token in the same family ────────────────────────
        // NOTE(review): loadUser is not defined in this snippet.
        AppUser user = loadUser(stored.getUserEmail());
        return issueTokenPair(user, stored.getFamilyId(), request);
    }

    /** Persists a new hashed refresh token in {@code familyId} and mints an access token. */
    private TokenPair issueTokenPair(AppUser user, String familyId,
                                      HttpServletRequest request) {
        String rawToken = generateRawToken();
        Instant now     = Instant.now();

        RefreshToken token = new RefreshToken();
        token.setTokenHash(hashToken(rawToken));
        token.setUserEmail(user.getEmail());
        token.setFamilyId(familyId);
        token.setCreatedAt(now);
        token.setExpiresAt(now.plusSeconds(refreshTokenExpirySeconds));
        // Client-supplied headers can exceed the column widths. Truncate
        // rather than fail the login/refresh with a constraint violation.
        token.setDeviceInfo(truncate(
            request.getHeader(HttpHeaders.USER_AGENT),
            DEVICE_INFO_MAX_LENGTH));
        token.setIpAddress(truncate(
            getClientIp(request), IP_ADDRESS_MAX_LENGTH));
        tokenRepo.save(token);

        String accessToken =
            jwtService.generateAccessToken(user);

        return new TokenPair(accessToken, rawToken,
            user.getId(), user.getEmail());
    }

    /** 64 random bytes, base64url-encoded without padding (86 chars). */
    private String generateRawToken() {
        byte[] bytes = new byte[64];
        RANDOM.nextBytes(bytes);
        return Base64.getUrlEncoder()
            .withoutPadding().encodeToString(bytes);
    }

    /** Lowercase hex SHA-256 of the UTF-8 bytes of {@code raw}. */
    private String hashToken(String raw) {
        try {
            MessageDigest md =
                MessageDigest.getInstance("SHA-256");
            return HexFormat.of().formatHex(
                md.digest(raw.getBytes(StandardCharsets.UTF_8)));
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to provide SHA-256.
            throw new IllegalStateException(e);
        }
    }

    /**
     * Best-effort client IP for audit purposes only. X-Forwarded-For is
     * client-controlled unless a trusted proxy overwrites it, so never use
     * this value for authorization decisions.
     */
    private String getClientIp(HttpServletRequest req) {
        String fwd = req.getHeader("X-Forwarded-For");
        return fwd != null && !fwd.isBlank()
            ? fwd.split(",")[0].trim()
            : req.getRemoteAddr();
    }

    /** Returns {@code value} cut to at most {@code maxLength} chars; null stays null. */
    private static String truncate(String value, int maxLength) {
        return value == null || value.length() <= maxLength
            ? value
            : value.substring(0, maxLength);
    }

    public record TokenPair(
        String accessToken,
        String refreshToken,
        Long   userId,
        String email
    ) {}
}

Refresh Token Revocation

Revoke refresh tokens on logout, password change, suspicious activity, or admin action. Because access tokens are stateless they cannot be revoked before expiry — keep their TTL short (5–15 minutes). For instant access token revocation, maintain a short-lived blocklist in Redis.
Java
@Service
@RequiredArgsConstructor
@Slf4j
public class TokenRevocationService {

    /** Redis key prefix for blocklisted access-token JTIs. */
    private static final String BLOCKLIST_PREFIX = "blocked:jwt:";

    private final RefreshTokenRepository  tokenRepo;
    private final StringRedisTemplate     redisTemplate;

    @Value("${app.jwt.access-token-expiry:900}")
    private long accessTokenExpirySeconds;

    // ── Revoke single refresh token (logout from one device) ──────────
    /**
     * Marks the given refresh token revoked. Unknown, null or blank tokens
     * are ignored, so logout is idempotent.
     */
    @Transactional
    public void revokeRefreshToken(String rawToken) {
        if (rawToken == null || rawToken.isBlank()) {
            return;
        }
        String hash = hashToken(rawToken);
        tokenRepo.findByTokenHash(hash).ifPresent(t -> {
            t.setRevoked(true);
            tokenRepo.save(t);
        });
    }

    // ── Revoke all refresh tokens for user (logout all devices) ───────
    @Transactional
    public void revokeAllForUser(String email) {
        tokenRepo.revokeAllForUser(email);
        log.info("Revoked all refresh tokens for: {}", email);
    }

    // ── Blocklist an access token until its expiry ────────────────────
    /**
     * Blocklists an access token by its {@code jti} for
     * {@code remainingSeconds}. Use it when an access token must stop
     * working before it expires, for example after a password change, an
     * account suspension, or a detected compromise.
     *
     * <p>A missing jti cannot be blocklisted; it is logged and skipped. A
     * non-positive duration means the token has already expired, so nothing
     * is stored. Redis would also reject a non-positive expiry.
     */
    public void blocklistAccessToken(String jti, long remainingSeconds) {
        if (jti == null || jti.isBlank()) {
            log.warn("Cannot blocklist an access token without a jti");
            return;
        }
        if (remainingSeconds <= 0) {
            return;
        }
        String key = BLOCKLIST_PREFIX + jti;
        redisTemplate.opsForValue()
            .set(key, "1",
                Duration.ofSeconds(remainingSeconds));
        log.info("Access token {} added to blocklist " +
                 "for {} seconds", jti, remainingSeconds);
    }

    /** Returns true if {@code jti} is currently on the blocklist. */
    public boolean isAccessTokenBlocked(String jti) {
        return Boolean.TRUE.equals(
            redisTemplate.hasKey(BLOCKLIST_PREFIX + jti));
    }

    // ── Trigger revocation on password change ─────────────────────────
    /**
     * Revokes every refresh token of the user and blocklists the access
     * token used for the password change.
     *
     * <p>The blocklist entry lives for the full configured access-token
     * lifetime. That is an upper bound on the token's remaining validity,
     * so the entry never expires too early. NOTE(review): access tokens
     * held by the user's other devices stay valid until they expire.
     */
    @Transactional
    public void onPasswordChange(String email,
                                  String currentJti) {
        revokeAllForUser(email);
        blocklistAccessToken(currentJti, accessTokenExpirySeconds);
        log.info("Revoked all tokens on password change for: {}",
            email);
    }

    // ── Scheduled cleanup ─────────────────────────────────────────────
    @Scheduled(cron = "0 0 4 * * *")   // 4 AM daily
    @Transactional
    public void cleanUp() {
        int deleted = tokenRepo.deleteExpiredAndRevoked(
            Instant.now());
        log.info("Deleted {} expired/revoked refresh tokens",
            deleted);
    }

    /**
     * Lowercase hex SHA-256 of the UTF-8 bytes of {@code raw}. This must
     * match the hash used when tokens are issued, or lookups will miss.
     */
    private static String hashToken(String raw) {
        try {
            java.security.MessageDigest md =
                java.security.MessageDigest.getInstance("SHA-256");
            return java.util.HexFormat.of().formatHex(
                md.digest(raw.getBytes(
                    java.nio.charset.StandardCharsets.UTF_8)));
        } catch (java.security.NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }
}

// ── Update JWT filter to check blocklist ─────────────────────────────
@Component
@RequiredArgsConstructor
public class JwtAuthenticationFilter extends OncePerRequestFilter {

    private final JwtService              jwtService;
    private final TokenRevocationService  revocationService;
    private final CustomUserDetailsService userDetailsService;

    @Override
    protected void doFilterInternal(HttpServletRequest  request,
                                    HttpServletResponse response,
                                    FilterChain         chain)
            throws ServletException, IOException {
        String jwt = extractToken(request);
        if (jwt == null) {
            chain.doFilter(request, response);
            return;
        }
        try {
            String jti = jwtService.extractClaim(
                jwt, Claims::getId);

            // Check blocklist before further processing
            if (jti != null &&
                    revocationService.isAccessTokenBlocked(jti)) {
                chain.doFilter(request, response);
                return;
            }

            String email = jwtService.extractSubject(jwt);
            if (email != null && SecurityContextHolder
                    .getContext().getAuthentication() == null) {

                UserDetails user =
                    userDetailsService.loadUserByUsername(email);
                if (jwtService.isValid(jwt, user)) {
                    var authToken =
                        new UsernamePasswordAuthenticationToken(
                            user, null, user.getAuthorities());
                    authToken.setDetails(
                        new WebAuthenticationDetailsSource()
                            .buildDetails(request));
                    SecurityContextHolder.getContext()
                        .setAuthentication(authToken);
                }
            }
        } catch (JwtException ignored) { }

        chain.doFilter(request, response);
    }

    private String extractToken(HttpServletRequest req) {
        String header = req.getHeader(HttpHeaders.AUTHORIZATION);
        return header != null && header.startsWith("Bearer ")
            ? header.substring(7) : null;
    }
}

Redis-Backed Refresh Tokens

For high-throughput applications, store refresh tokens in Redis instead of a relational database. Redis TTLs handle expiry automatically. An atomic GETDEL ensures each refresh token can be claimed only once, even under concurrent requests. In-memory lookups are also typically much faster than a SQL query. Use a sorted set, scored by expiry, to list each user's active sessions.
Java
/**
 * Refresh-token storage in Redis.
 *
 * <p>Key layout:
 * <ul>
 *   <li>{@code rt:token:<hash>}: {@code "<email>:<familyId>"}, TTL = token lifetime</li>
 *   <li>{@code rt:revoked:<hash>}: familyId of a rotated token, kept for reuse detection</li>
 *   <li>{@code rt:user:<email>}: sorted set of the user's token hashes, scored by expiry (epoch ms)</li>
 *   <li>{@code rt:family:<familyId>}: set of token hashes issued in one login family</li>
 * </ul>
 *
 * <p>NOTE(review): {@code loadUser} and {@code TokenPair} are not defined in
 * this class; they are assumed to come from elsewhere.
 */
@Service
@RequiredArgsConstructor
@Slf4j
public class RedisRefreshTokenService {

    private final RedisTemplate<String, String> redisTemplate;
    private final JwtService                    jwtService;

    private static final String TOKEN_KEY   = "rt:token:";
    private static final String USER_KEY    = "rt:user:";
    private static final String FAMILY_KEY  = "rt:family:";
    private static final String REVOKED_KEY = "rt:revoked:";

    /** One shared CSPRNG; SecureRandom is safe for concurrent use. */
    private static final java.security.SecureRandom RANDOM =
        new java.security.SecureRandom();

    @Value("${app.jwt.refresh-token-expiry:604800}")
    private long expirySeconds;

    // ── Create token ──────────────────────────────────────────────────
    /** Starts a new login family and issues its first token pair. */
    public TokenPair create(AppUser user,
                             HttpServletRequest request) {
        String familyId  = UUID.randomUUID().toString();
        return issue(user, familyId, request);
    }

    // ── Rotate token ──────────────────────────────────────────────────
    /**
     * Exchanges a refresh token for a new pair in the same family.
     *
     * @throws TokenTheftException   if an already-rotated token is presented;
     *                               the whole family is revoked first
     * @throws InvalidTokenException if the token is unknown or expired
     */
    public TokenPair rotate(String rawToken,
                             HttpServletRequest request) {
        String hash = hashToken(rawToken);

        // GETDEL claims the token atomically. Of two concurrent requests
        // presenting the same token, exactly one receives its data.
        String tokenData = redisTemplate.opsForValue()
            .getAndDelete(TOKEN_KEY + hash);

        if (tokenData == null) {
            // Theft detection: was this token already rotated?
            String familyId = redisTemplate.opsForValue()
                .get(REVOKED_KEY + hash);
            if (familyId != null) {
                revokeFamily(familyId);
                log.warn("Token reuse detected — " +
                    "revoked family {}", familyId);
                throw new TokenTheftException(
                    "Security alert: please log in again");
            }
            throw new InvalidTokenException(
                "Refresh token not found or expired");
        }

        // Stored as "email:familyId". The familyId is a UUID and contains no
        // ':', so split at the LAST colon; the email part may contain one.
        int sep = tokenData.lastIndexOf(':');
        if (sep <= 0) {
            throw new IllegalStateException(
                "Malformed refresh token record");
        }
        String email    = tokenData.substring(0, sep);
        String familyId = tokenData.substring(sep + 1);

        // Remember the rotated token so a later replay is recognised.
        redisTemplate.opsForValue().set(
            REVOKED_KEY + hash, familyId,
            Duration.ofSeconds(expirySeconds));
        redisTemplate.opsForZSet()
            .remove(USER_KEY + email, hash);

        AppUser user = loadUser(email);
        return issue(user, familyId, request);
    }

    /** Stores a new token in {@code familyId}, indexes it, and mints an access token. */
    private TokenPair issue(AppUser user, String familyId,
                             HttpServletRequest request) {
        String   rawToken  = generateRawToken();
        String   hash      = hashToken(rawToken);
        Instant  now       = Instant.now();
        Duration ttl       = Duration.ofSeconds(expirySeconds);
        String   userKey   = USER_KEY + user.getEmail();
        String   familyKey = FAMILY_KEY + familyId;

        // Store token data
        redisTemplate.opsForValue().set(
            TOKEN_KEY + hash,
            user.getEmail() + ":" + familyId,
            ttl);

        // Track the user's tokens, scored by expiry. Prune entries that have
        // already expired (their token keys are gone via TTL). Give the set
        // a TTL matching its newest member so an idle user's set disappears.
        redisTemplate.opsForZSet().removeRangeByScore(
            userKey, 0, now.toEpochMilli());
        redisTemplate.opsForZSet().add(
            userKey, hash, now.plus(ttl).toEpochMilli());
        redisTemplate.expire(userKey, ttl);

        // Track all tokens in this family
        redisTemplate.opsForSet().add(familyKey, hash);
        redisTemplate.expire(familyKey, ttl);

        String accessToken =
            jwtService.generateAccessToken(user);
        return new TokenPair(accessToken, rawToken,
            user.getId(), user.getEmail());
    }

    /** Deletes every refresh token of the user and the user's index. */
    public void revokeAll(String email) {
        String userKey = USER_KEY + email;
        Set<String> hashes =
            redisTemplate.opsForZSet().range(userKey, 0, -1);
        if (hashes != null && !hashes.isEmpty()) {
            redisTemplate.delete(
                hashes.stream().map(h -> TOKEN_KEY + h).toList());
        }
        redisTemplate.delete(userKey);
        log.info("Revoked all Redis refresh tokens for: {}",
            email);
    }

    /** Deletes every live token of the family and removes each from its user's index. */
    private void revokeFamily(String familyId) {
        String familyKey = FAMILY_KEY + familyId;
        Set<String> hashes = redisTemplate.opsForSet()
            .members(familyKey);
        if (hashes != null) {
            for (String h : hashes) {
                String data = redisTemplate.opsForValue()
                    .getAndDelete(TOKEN_KEY + h);
                if (data != null) {
                    int sep = data.lastIndexOf(':');
                    if (sep > 0) {
                        redisTemplate.opsForZSet().remove(
                            USER_KEY + data.substring(0, sep), h);
                    }
                }
            }
        }
        redisTemplate.delete(familyKey);
    }

    /** 64 random bytes, base64url-encoded without padding. */
    private static String generateRawToken() {
        byte[] bytes = new byte[64];
        RANDOM.nextBytes(bytes);
        return java.util.Base64.getUrlEncoder()
            .withoutPadding().encodeToString(bytes);
    }

    /** Lowercase hex SHA-256 of the UTF-8 bytes of {@code raw}. */
    private static String hashToken(String raw) {
        try {
            java.security.MessageDigest md =
                java.security.MessageDigest.getInstance("SHA-256");
            return java.util.HexFormat.of().formatHex(
                md.digest(raw.getBytes(
                    java.nio.charset.StandardCharsets.UTF_8)));
        } catch (java.security.NoSuchAlgorithmException e) {
            throw new IllegalStateException(e);
        }
    }
}

Silent Refresh and Token Expiry Handling

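The client should refresh the access token silently, so the user is never interrupted. Detect the expired-token response in a response interceptor and automatically call the refresh endpoint. Return a 401 with a machine-readable error code so the client knows to refresh rather than redirect to login. Exceptions thrown inside a servlet filter, such as the JWT authentication filter, never reach a @RestControllerAdvice. For expired access tokens, the filter itself must therefore write this response. Because refresh tokens rotate, the client must also send concurrent 401s through a single in-flight refresh request. If two requests each call the refresh endpoint with the same token, the second presents an already-rotated token, triggers reuse detection, and logs the user out.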
Java
// ── Structured token expiry response ─────────────────────────────────
@RestControllerAdvice
public class TokenExceptionHandler {

    // ── Access token expired → client should use refresh token ────────
    @ExceptionHandler(ExpiredJwtException.class)
    public ResponseEntity<TokenErrorResponse> handleExpired(
            ExpiredJwtException ex, HttpServletRequest req) {
        return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
            .header("WWW-Authenticate",
                "Bearer error="invalid_token"," +
                " error_description="Token expired"")
            .body(new TokenErrorResponse(
                "TOKEN_EXPIRED",
                "Access token has expired",
                req.getRequestURI()
            ));
    }

    // ── Invalid / tampered token ──────────────────────────────────────
    @ExceptionHandler({MalformedJwtException.class,
                       SignatureException.class,
                       UnsupportedJwtException.class})
    public ResponseEntity<TokenErrorResponse> handleInvalid(
            JwtException ex, HttpServletRequest req) {
        return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
            .body(new TokenErrorResponse(
                "TOKEN_INVALID",
                "Access token is invalid",
                req.getRequestURI()
            ));
    }

    // ── Refresh token stolen / reused ──────────────────────────────────
    @ExceptionHandler(TokenTheftException.class)
    public ResponseEntity<TokenErrorResponse> handleTheft(
            TokenTheftException ex, HttpServletRequest req) {
        return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
            .body(new TokenErrorResponse(
                "TOKEN_THEFT_DETECTED",
                ex.getMessage(),
                req.getRequestURI()
            ));
    }

    public record TokenErrorResponse(
        String code,
        String message,
        String path
    ) {}
}

// ── Client-side silent refresh (TypeScript / Axios interceptor) ───────
// axiosInstance.interceptors.response.use(
//   response => response,
//   async error => {
//     const original = error.config;
//     if (error.response?.status === 401 &&
//         error.response?.data?.code === 'TOKEN_EXPIRED' &&
//         !original._retry) {
//       original._retry = true;
//       const { data } = await axios.post('/api/v1/auth/refresh',
//         { refreshToken: getRefreshToken() });
//       setAccessToken(data.accessToken);
//       original.headers.Authorization =
//         'Bearer ' + data.accessToken;
//       return axiosInstance(original);
//     }
//     return Promise.reject(error);
//   }
// );