Sharding in Java Spring Boot - Part 1: Foundation and Routing Infrastructure
How Application-Level Sharding Is Built from the Ground Up
Table of Contents
- Project Structure and Dependencies
- Configuration: DataSource Per Shard with HikariCP
- The Shard Context: ThreadLocal State Management
- The Routing DataSource: AbstractRoutingDataSource
- Shard Key Resolution
- Custom Annotations and AOP: Automatic Routing
- Entity Layer
- Repository Layer
- Service Layer: Putting It All Together
- REST Controller Layer
- Common Mistakes and How Spring Boot Exposes Them
Before We Start: The Mental Model
In a non-sharded Spring Boot app, the flow is:
HTTP Request
|
v
Controller (@RestController)
|
v
Service (@Service, @Transactional)
|
v
Repository (Spring Data JPA)
|
v
ONE DataSource -> ONE Database
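In a sharded Spring Boot app, two pieces are added: an AOP aspect on the service layer that records the target shard, and a routing DataSource between the repositories and the physical databases: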
HTTP Request
|
v
Controller
|
v
Service (@ShardedOperation -> AOP sets ShardContext)
|
v
Repository (Spring Data JPA - unchanged, knows nothing about shards)
|
v
ShardRoutingDataSource (reads ShardContext, routes to correct DataSource)
|
+--> DataSource for Shard 0 --> PostgreSQL Shard 0
+--> DataSource for Shard 1 --> PostgreSQL Shard 1
+--> DataSource for Shard 2 --> PostgreSQL Shard 2
+--> DataSource for Shard 3 --> PostgreSQL Shard 3
The repository code does not change. The entity code does not change. What changes is:
- You configure multiple DataSources (one per shard)
- You implement AbstractRoutingDataSource to choose between them
- You manage a ThreadLocal that holds the current shard number
- You add an AOP aspect that sets the ThreadLocal before any database call
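To make the moving parts concrete, here is a Spring-free sketch of the same idea: a ThreadLocal holds the chosen shard, a router reads it to pick a target, and a wrapper sets and always clears it. The names (RoutingModel, onShard, currentTarget) are illustrative only; in the real application these roles are played by ShardContext, ShardRoutingDataSource and the AOP aspect, and the targets are DataSources rather than strings.

```java
import java.util.Map;
import java.util.function.Supplier;

// Spring-free model of application-level shard routing.
// The "databases" are plain strings; in the real app they are HikariCP DataSources.
public class RoutingModel {
    // Role of ShardContext: the shard chosen for the current thread.
    static final ThreadLocal<Integer> CURRENT_SHARD = new ThreadLocal<>();

    // Role of the per-shard DataSources.
    static final Map<Integer, String> TARGETS =
            Map.of(0, "shard0_db", 1, "shard1_db", 2, "shard2_db", 3, "shard3_db");

    // Role of ShardRoutingDataSource: pick the target for this thread.
    // A missing context falls back to shard 0, like defaultTargetDataSource.
    static String currentTarget() {
        Integer shard = CURRENT_SHARD.get();
        return TARGETS.get(shard == null ? 0 : shard);
    }

    // Role of the AOP aspect: set the shard, run the work, always clean up.
    static <T> T onShard(int shardId, Supplier<T> work) {
        CURRENT_SHARD.set(shardId);
        try {
            return work.get();
        } finally {
            CURRENT_SHARD.remove();
        }
    }

    public static void main(String[] args) {
        // The "repository" only asks for the current target; it never sees a shard ID.
        System.out.println(onShard(2, RoutingModel::currentTarget)); // shard2_db
        System.out.println(currentTarget());                          // shard0_db
    }
}
```

Note that currentTarget() is identical in both calls; only the thread's context differs, which is exactly why the repository layer needs no changes.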
1. Project Structure and Dependencies
Maven pom.xml
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.2.5</version>
</parent>
<dependencies>
<!-- Core Spring Boot -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-jpa</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- AOP - required for @ShardedOperation annotation processing -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<!-- Database driver -->
<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
</dependency>
<!-- Schema migrations -->
<dependency>
<groupId>org.flywaydb</groupId>
<artifactId>flyway-core</artifactId>
</dependency>
<!-- Circuit breaker per shard -->
<dependency>
<groupId>io.github.resilience4j</groupId>
<artifactId>resilience4j-spring-boot3</artifactId>
<version>2.2.0</version>
</dependency>
<!-- Metrics -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-actuator</artifactId>
</dependency>
<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-registry-prometheus</artifactId>
</dependency>
<!-- Kafka for outbox relay -->
<dependency>
<groupId>org.springframework.kafka</groupId>
<artifactId>spring-kafka</artifactId>
</dependency>
<!-- Lombok -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<!-- Testing -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>postgresql</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
Project Package Structure
src/main/java/com/example/sharding/
  config/
    ShardProperties.java          <- Maps application.yml shard config
    ShardDataSourceConfig.java    <- Creates HikariCP DataSource per shard
  context/
    ShardContext.java             <- ThreadLocal shard state
    ShardContextHolder.java       <- Safe API: withKey(), withShard(), scatterGather()
  routing/
    ShardKeyResolver.java         <- hash(key) mod numShards
    ShardRoutingDataSource.java   <- Extends AbstractRoutingDataSource
  annotation/
    ShardedOperation.java         <- @ShardedOperation(shardKeyParam="userId")
    ShardKey.java                 <- @ShardKey on method parameter
  aspect/
    ShardingAspect.java           <- AOP: reads annotation, sets ShardContext
  entity/
    User.java
    Order.java
    OutboxEvent.java
  repository/
    UserRepository.java
    OrderRepository.java
    OutboxRepository.java
    EmailIndexRepository.java     <- Non-sharded secondary index
  service/
    UserService.java
    OrderService.java
    CrossShardQueryService.java
  id/
    SnowflakeIdGenerator.java     <- Globally unique 64-bit IDs
  health/
    ShardHealthIndicator.java
src/main/resources/
  application.yml
  db/migration/
    V1__create_tables.sql         <- users, orders, outbox_events (same schema on every shard)
application.yml
spring:
  application:
    name: sharded-app
  # Disable Spring Boot's auto-configuration of a single DataSource and
  # EntityManagerFactory. We create the DataSources and JPA beans manually.
  autoconfigure:
    exclude:
      - org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration
      - org.springframework.boot.autoconfigure.orm.jpa.HibernateJpaAutoConfiguration
  # Flyway auto-configuration would migrate only the routing DataSource's default
  # target (shard 0). Every shard is migrated by ShardedFlywayMigrationRunner (see Part 2).
  flyway:
    enabled: false
  # With HibernateJpaAutoConfiguration excluded, Spring Boot does NOT apply spring.jpa.*.
  # The Hibernate settings actually used are the ones set in
  # ShardDataSourceConfig.entityManagerFactory(); this block documents the intent only.
  jpa:
    hibernate:
      ddl-auto: validate   # Flyway manages schema, not Hibernate
    show-sql: false
    properties:
      hibernate:
        dialect: org.hibernate.dialect.PostgreSQLDialect
        jdbc.batch_size: 50
        order_inserts: true
        order_updates: true
app:
  sharding:
    num-shards: 4
    shards:
      - id: 0
        url: jdbc:postgresql://localhost:5432/shard0_db
        username: app_user
        password: ${SHARD0_PASSWORD:localpassword}
        pool-size: 20
      - id: 1
        url: jdbc:postgresql://localhost:5433/shard1_db
        username: app_user
        password: ${SHARD1_PASSWORD:localpassword}
        pool-size: 20
      - id: 2
        url: jdbc:postgresql://localhost:5434/shard2_db
        username: app_user
        password: ${SHARD2_PASSWORD:localpassword}
        pool-size: 20
      - id: 3
        url: jdbc:postgresql://localhost:5435/shard3_db
        username: app_user
        password: ${SHARD3_PASSWORD:localpassword}
        pool-size: 20
  snowflake:
    machine-id: ${POD_INDEX:0}   # Set per Kubernetes pod replica via env var
management:
  endpoints:
    web:
      exposure:
        include: health, metrics, prometheus
  endpoint:
    health:
      show-details: always
resilience4j:
  circuitbreaker:
    configs:
      shard-default:               # Shared settings so every shard's breaker behaves the same
        slidingWindowSize: 10
        failureRateThreshold: 50
        waitDurationInOpenState: 30s
        permittedNumberOfCallsInHalfOpenState: 3
        automaticTransitionFromOpenToHalfOpenEnabled: true
    instances:
      shard-0:
        baseConfig: shard-default
      shard-1:
        baseConfig: shard-default
      shard-2:
        baseConfig: shard-default
      shard-3:
        baseConfig: shard-default
2. Configuration: DataSource Per Shard with HikariCP
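ShardKeyResolver returns shard numbers in the range 0 to num-shards - 1, and ShardContextHolder.scatterGather() loops over that same range, so the shards list in application.yml must contain exactly those IDs. Nothing in the configuration classes below checks this. A minimal startup check could look like the following sketch; ShardConfigValidator is a hypothetical helper, not part of the article's code, and in the application it might be called from a @PostConstruct method with shardProperties.getNumShards() and the IDs from shardProperties.getShards().

```java
import java.util.List;
import java.util.stream.IntStream;

// Hypothetical startup check; not part of the article's ShardProperties.
public class ShardConfigValidator {

    // Throws unless configuredIds is exactly {0, 1, ..., numShards - 1}:
    // catches missing IDs, duplicates, gaps, and a num-shards / list-size mismatch.
    static void validate(int numShards, List<Integer> configuredIds) {
        if (numShards <= 0) {
            throw new IllegalStateException("num-shards must be positive, got " + numShards);
        }
        List<Integer> expected = IntStream.range(0, numShards).boxed().toList();
        List<Integer> actual = configuredIds.stream().sorted().toList();
        if (!actual.equals(expected)) {
            throw new IllegalStateException(
                    "Shard IDs must be exactly " + expected + " but were " + actual);
        }
    }

    public static void main(String[] args) {
        validate(4, List.of(3, 1, 0, 2)); // OK: order in the YAML does not matter
        try {
            validate(4, List.of(0, 1, 3)); // shard 2 missing
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Failing at startup is preferable here: a missing shard ID would otherwise only surface when a key happens to resolve to it.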
ShardProperties.java - Maps application.yml to a Java class
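// Registered by @EnableConfigurationProperties on ShardDataSourceConfig.
// Do not also annotate it with @Component: registering it both ways creates
// two ShardProperties beans and makes injection by type ambiguous.
@ConfigurationProperties(prefix = "app.sharding")
@Data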
public class ShardProperties {
private int numShards;
private List<ShardConfig> shards = new ArrayList<>();
@Data
public static class ShardConfig {
private int id;
private String url;
private String username;
private String password;
private int poolSize = 10;
}
}
ShardDataSourceConfig.java - Creates one HikariCP DataSource per shard
@Configuration
@EnableConfigurationProperties(ShardProperties.class)
@Slf4j
public class ShardDataSourceConfig {
private final ShardProperties shardProperties;
public ShardDataSourceConfig(ShardProperties shardProperties) {
this.shardProperties = shardProperties;
}
/**
* Creates one named HikariCP DataSource per shard.
* Returns a Map so other beans can reference individual shards.
*/
@Bean("shardDataSources")
public Map<Integer, DataSource> shardDataSources() {
Map<Integer, DataSource> map = new LinkedHashMap<>();
for (ShardProperties.ShardConfig config : shardProperties.getShards()) {
HikariDataSource ds = buildDataSource(config);
map.put(config.getId(), ds);
log.info("Initialized HikariCP pool for shard {}: url={}",
config.getId(), config.getUrl());
}
return map;
}
private HikariDataSource buildDataSource(ShardProperties.ShardConfig config) {
HikariConfig hikari = new HikariConfig();
hikari.setJdbcUrl(config.getUrl());
hikari.setUsername(config.getUsername());
hikari.setPassword(config.getPassword());
// Pool sizing
hikari.setMaximumPoolSize(config.getPoolSize());
hikari.setMinimumIdle(config.getPoolSize() / 2);
// Timeouts
hikari.setConnectionTimeout(3_000); // 3s: time to get connection from pool
hikari.setIdleTimeout(600_000); // 10min: idle connection eviction
hikari.setMaxLifetime(1_800_000); // 30min: max age before forced recycling
hikari.setKeepaliveTime(60_000); // 1min: periodic keepalive ping
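// Validation: the PostgreSQL driver is JDBC4-compliant, so HikariCP checks
// connections with Connection.isValid(). HikariCP recommends NOT setting a
// connectionTestQuery for such drivers, so setConnectionTestQuery() is omitted.
hikari.setValidationTimeout(2_000);
// Pool name - identifies this shard's pool in logs and metrics. JMX exposure
// needs setRegisterMbeans(true). Spring Boot binds Hikari metrics only for
// DataSource beans; these pools live inside a Map bean, so bind them to
// Micrometer explicitly (e.g. hikari.setMetricRegistry(meterRegistry)).
hikari.setPoolName("HikariPool-Shard-" + config.getId());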
return new HikariDataSource(hikari);
}
/**
* The routing DataSource - the single DataSource that Spring JPA will use.
* It reads ShardContext on every DB call and delegates to the right shard DataSource.
*/
@Bean
@Primary
public DataSource routingDataSource(
@Qualifier("shardDataSources") Map<Integer, DataSource> shardDataSources) {
ShardRoutingDataSource routing = new ShardRoutingDataSource();
// AbstractRoutingDataSource needs Map<Object, Object>
Map<Object, Object> targets = new HashMap<>();
shardDataSources.forEach((k, v) -> targets.put(k, v));
routing.setTargetDataSources(targets);
routing.setDefaultTargetDataSource(shardDataSources.get(0));
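// Resolves targetDataSources into the internal lookup map. Spring also calls this
// for the returned @Bean; calling it here makes the bean usable immediately.
routing.afterPropertiesSet();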
return routing;
}
/**
* Manual EntityManagerFactory because we disabled auto-configuration.
* Must point to the routing DataSource.
*/
@Bean
public LocalContainerEntityManagerFactoryBean entityManagerFactory(
@Qualifier("routingDataSource") DataSource dataSource) {
LocalContainerEntityManagerFactoryBean emf =
new LocalContainerEntityManagerFactoryBean();
emf.setDataSource(dataSource);
emf.setPackagesToScan("com.example.sharding.entity");
HibernateJpaVendorAdapter vendorAdapter = new HibernateJpaVendorAdapter();
vendorAdapter.setGenerateDdl(false);
emf.setJpaVendorAdapter(vendorAdapter);
Properties props = new Properties();
props.put("hibernate.dialect", "org.hibernate.dialect.PostgreSQLDialect");
props.put("hibernate.jdbc.batch_size", "50");
props.put("hibernate.order_inserts", "true");
props.put("hibernate.order_updates", "true");
emf.setJpaProperties(props);
return emf;
}
@Bean
public PlatformTransactionManager transactionManager(
EntityManagerFactory entityManagerFactory) {
return new JpaTransactionManager(entityManagerFactory);
}
}
3. The Shard Context: ThreadLocal State Management
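The failure mode described in the ShardContext Javadoc below (a pooled thread carrying one request's shard into the next) can be reproduced without Spring. This sketch uses a single-thread executor so both "requests" are guaranteed to run on the same worker thread; ThreadLocalLeakDemo and secondRequestSees are hypothetical names used only for illustration.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Reproduces a ThreadLocal leak with a single-thread pool, so both "requests"
// are guaranteed to run on the same reused worker thread.
public class ThreadLocalLeakDemo {
    static final ThreadLocal<Integer> CURRENT_SHARD = new ThreadLocal<>();

    // Runs request 1 (routes to shard 2) and then request 2 (sets nothing)
    // on the same thread; returns the shard request 2 observes.
    static Integer secondRequestSees(boolean cleanUp) {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            Runnable request1 = () -> {
                CURRENT_SHARD.set(2);
                try {
                    // ... database work for request 1 ...
                } finally {
                    if (cleanUp) {
                        CURRENT_SHARD.remove();
                    }
                }
            };
            Callable<Integer> request2 = CURRENT_SHARD::get;
            pool.submit(request1).get();
            return pool.submit(request2).get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("without cleanup: " + secondRequestSees(false)); // 2 (leaked)
        System.out.println("with cleanup:    " + secondRequestSees(true));  // null
    }
}
```

In the sharded application the leaked value would not just be printed: the second request's queries would silently run on shard 2.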
ShardContext.java
/**
* ThreadLocal-based holder for the current request's shard ID.
*
* CRITICAL RULE: Every call to set() MUST be followed by clear() in a finally block.
* Failure to clear causes "ThreadLocal leak" in thread pools:
* Thread T1 sets shard=2, processes a request, returns to the pool.
* Thread T1 is reused for the next request. ShardContext still shows shard=2.
* The next request accidentally queries shard 2 regardless of its actual shard key.
* This is silent data corruption. The bug is intermittent and very hard to debug.
*
* SAFE pattern: Always use ShardContextHolder.withKey() which guarantees cleanup.
* UNSAFE pattern: Call ShardContext.set() directly without a corresponding finally.
*/
public final class ShardContext {
private static final ThreadLocal<Integer> CURRENT_SHARD = new ThreadLocal<>();
private ShardContext() {}
public static void set(int shardId) {
CURRENT_SHARD.set(shardId);
}
public static Integer get() {
return CURRENT_SHARD.get();
}
/**
* Use remove() not set(null).
* set(null) leaves the ThreadLocal entry in the ThreadLocalMap.
* remove() actually deletes the entry, preventing memory leaks.
*/
public static void clear() {
CURRENT_SHARD.remove();
}
public static boolean isSet() {
return CURRENT_SHARD.get() != null;
}
}
ShardContextHolder.java - The Safe Public API
/**
* Public API for shard context management.
* All database operations should go through these methods, not ShardContext directly.
*
* All methods guarantee ShardContext.clear() is called, even if the operation throws.
*/
@Component
@Slf4j
public class ShardContextHolder {
private final ShardKeyResolver shardKeyResolver;
private final ShardProperties shardProperties;
public ShardContextHolder(ShardKeyResolver shardKeyResolver,
ShardProperties shardProperties) {
this.shardKeyResolver = shardKeyResolver;
this.shardProperties = shardProperties;
}
/**
* Execute an operation on the shard determined by a long key (user_id, order_id).
*
* Usage:
* Optional<User> user = shardContextHolder.withKey(userId, () -> userRepository.findById(userId));
*/
public <T> T withKey(long shardKey, Supplier<T> operation) {
int shardId = shardKeyResolver.resolve(shardKey);
return withShard(shardId, operation);
}
/**
* Execute a void operation on the shard for the given key.
*/
public void withKey(long shardKey, Runnable operation) {
int shardId = shardKeyResolver.resolve(shardKey);
withShard(shardId, () -> { operation.run(); return null; });
}
/**
* Execute on a specific, named shard.
* Prefer withKey() when you have a natural shard key.
*/
public <T> T withShard(int shardId, Supplier<T> operation) {
// Detect nested shard context (one service calling another)
Integer previous = ShardContext.get();
if (previous != null && !previous.equals(shardId)) {
log.warn("Nested shard context detected. " +
"Outer shard: {}, Requested shard: {}. " +
"This may indicate a cross-shard operation.", previous, shardId);
}
ShardContext.set(shardId);
try {
return operation.get();
} finally {
// If there was a previous context, restore it. Otherwise clear.
if (previous != null) {
ShardContext.set(previous);
} else {
ShardContext.clear();
}
}
}
/**
* Execute on ALL shards sequentially and collect results.
* This is the scatter-gather pattern.
*
* Example:
* List<Long> counts = scatterGather(shardId ->
* userRepository.countActiveUsers());
* long total = counts.stream().mapToLong(Long::longValue).sum();
*/
public <T> List<T> scatterGather(Function<Integer, T> operation) {
int numShards = shardProperties.getNumShards();
List<T> results = new ArrayList<>(numShards);
for (int shardId = 0; shardId < numShards; shardId++) {
final int id = shardId;
T result = withShard(id, () -> operation.apply(id));
results.add(result);
}
return results;
}
/**
* Scatter-gather in parallel using a provided ExecutorService.
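* Results from failed or timed-out shards are null (caller must handle).
* Each shard's future times out timeoutSeconds after submission. A timeout only
* stops waiting: the query keeps running on its worker thread and holds its
* pooled connection until it finishes.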
*/
public <T> List<T> scatterGatherParallel(
Function<Integer, T> operation,
ExecutorService executor,
int timeoutSeconds) {
int numShards = shardProperties.getNumShards();
List<CompletableFuture<T>> futures = IntStream.range(0, numShards)
.mapToObj(shardId ->
CompletableFuture.supplyAsync(
() -> withShard(shardId, () -> operation.apply(shardId)),
executor
)
.orTimeout(timeoutSeconds, TimeUnit.SECONDS)
.exceptionally(ex -> {
log.error("Scatter-gather shard {} failed: {}", shardId, ex.getMessage());
return null;
})
)
.toList();
return futures.stream()
.map(f -> {
try { return f.get(); }
catch (Exception e) { return null; }
})
.toList();
}
}
4. The Routing DataSource: AbstractRoutingDataSource
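Before the implementation, it helps to know what AbstractRoutingDataSource does with the key it gets back. The sketch below is a simplified, Spring-free model of the lookup in Spring's determineTargetDataSource(); the class and method names are hypothetical and DataSources are replaced by strings. The detail worth noticing is lenientFallback: Spring's default is true, so an unknown key (for example a shard number the map does not contain) also falls back silently to the default DataSource. AbstractRoutingDataSource.setLenientFallback(false) turns that case into an IllegalStateException, while a null key still uses the default.

```java
import java.util.Map;

// Simplified, Spring-free model of how AbstractRoutingDataSource turns the
// value of determineCurrentLookupKey() into a target.
public class RoutingLookupModel {

    static String determineTarget(Map<Integer, String> targets, String defaultTarget,
                                  Integer lookupKey, boolean lenientFallback) {
        // Map.of() rejects get(null), so only look up non-null keys.
        String target = (lookupKey == null) ? null : targets.get(lookupKey);
        // Fallback applies to a null key always, and to an unknown key only when lenient.
        if (target == null && (lenientFallback || lookupKey == null)) {
            target = defaultTarget;
        }
        if (target == null) {
            throw new IllegalStateException(
                    "Cannot determine target DataSource for lookup key [" + lookupKey + "]");
        }
        return target;
    }

    public static void main(String[] args) {
        Map<Integer, String> shards = Map.of(0, "shard0", 1, "shard1", 2, "shard2", 3, "shard3");
        System.out.println(determineTarget(shards, "shard0", 2, true));    // shard2
        System.out.println(determineTarget(shards, "shard0", null, true)); // shard0 (no context)
        System.out.println(determineTarget(shards, "shard0", 7, true));    // shard0 (bad key!)
        try {
            determineTarget(shards, "shard0", 7, false);
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```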
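/**
* Spring's AbstractRoutingDataSource routes each connection request to one of the
* configured target DataSources by calling determineCurrentLookupKey().
*
* Spring calls this method whenever a connection is requested from this DataSource:
* - When a @Transactional method starts (the transaction manager acquires the connection)
* - When non-transactional code (JdbcTemplate, a repository call outside a transaction)
*   needs a connection
* Inside an active transaction, repository calls reuse the already-bound connection
* and this method is NOT called again.
*
* The returned key must match a key in the targetDataSources Map.
*/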
@Slf4j
public class ShardRoutingDataSource extends AbstractRoutingDataSource {
@Override
protected Object determineCurrentLookupKey() {
Integer shardId = ShardContext.get();
if (shardId == null) {
// No shard was set. This is almost always a programming error.
// Returning null makes AbstractRoutingDataSource use the defaultTargetDataSource.
log.warn("ShardContext not set. Using default shard (0). " +
"All requests should have a shard context. " +
"Caller: {}", getCallerMethod());
return null; // Falls back to defaultTargetDataSource (shard 0)
}
log.debug("Routing to shard {}", shardId);
return shardId;
}
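private String getCallerMethod() {
// Show the first non-framework caller for debugging. Scan the whole stack:
// JPA call chains are usually deeper than 10 frames, and Hibernate, JDK proxy
// and this class's own frames must be skipped too.
for (StackTraceElement frame : Thread.currentThread().getStackTrace()) {
String cls = frame.getClassName();
if (!cls.startsWith("org.springframework") &&
!cls.startsWith("org.hibernate") &&
!cls.startsWith("com.zaxxer") &&
!cls.startsWith("java.") &&
!cls.startsWith("jdk.") &&
!cls.startsWith("sun.") &&
!cls.startsWith("com.sun.") &&
!cls.equals(ShardRoutingDataSource.class.getName())) {
return frame.toString();
}
}
return "unknown";
}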
}
How @Transactional Interacts with the Routing DataSource
This is the most important thing to understand:
Timeline of a @Transactional method on a sharded DataSource:
1. AOP ShardingAspect runs (ORDER=1): reads shard key, sets ShardContext to shard=2
2. @Transactional AOP runs (ORDER=MAX_VALUE): opens a DB connection
- AbstractRoutingDataSource.getConnection() is called
- determineCurrentLookupKey() returns 2 (from ShardContext)
- Connection is acquired from Shard 2's HikariCP pool
- Spring binds this connection to the current thread
3. Method body executes:
- userRepository.findById(...) -> uses bound connection -> Shard 2 ✓
- orderRepository.save(...) -> uses bound connection -> Shard 2 ✓
- NOTE: Changing ShardContext.set(3) HERE has NO EFFECT.
The connection is already bound to Shard 2.
4. @Transactional commits: COMMIT sent on Shard 2 connection
5. ShardingAspect clears ShardContext (ORDER=1 cleanup)
CONCLUSION:
- @Transactional is safe and works correctly on a SINGLE shard.
- ShardingAspect MUST run BEFORE @Transactional opens the connection.
That is why @Order(1) on ShardingAspect is non-negotiable.
- Changing ShardContext inside a @Transactional method does nothing.
The connection is already bound at transaction start.
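The ordering rule can be checked with a small Spring-free simulation. It assumes, as described above, that the transaction advice binds its connection (here it just records the shard) when the transaction starts, and that a missing context falls back to shard 0. AdviceOrderDemo and its methods are hypothetical names; nesting the two calls stands in for Spring's advice ordering.

```java
import java.util.function.Supplier;

// Spring-free simulation of two around-advices on one method:
// sharding advice (sets the context) and transaction advice
// (acquires and binds a connection when the transaction starts).
public class AdviceOrderDemo {
    static final ThreadLocal<Integer> CURRENT_SHARD = new ThreadLocal<>();

    // Shard whose "connection" the transaction bound (null = no transaction yet).
    static Integer boundShard;

    static <T> T shardingAdvice(int shardId, Supplier<T> next) {
        CURRENT_SHARD.set(shardId);
        try {
            return next.get();
        } finally {
            CURRENT_SHARD.remove();
        }
    }

    static <T> T transactionAdvice(Supplier<T> next) {
        Integer key = CURRENT_SHARD.get();
        boundShard = (key == null) ? 0 : key; // routing fallback: default shard 0
        return next.get();                    // body reuses this bound connection
    }

    // shardingOuter == true models @Order(1) on the aspect vs. the
    // transaction advisor's default lowest precedence.
    static Integer run(boolean shardingOuter, int shardId) {
        boundShard = null;
        Supplier<String> body = () -> "method body";
        if (shardingOuter) {
            shardingAdvice(shardId, () -> transactionAdvice(body));
        } else {
            transactionAdvice(() -> shardingAdvice(shardId, body));
        }
        return boundShard;
    }

    public static void main(String[] args) {
        System.out.println("sharding outer:    shard " + run(true, 2));  // shard 2
        System.out.println("transaction outer: shard " + run(false, 2)); // shard 0
    }
}
```

In the wrong order the shard context is still set correctly inside the method, which is what makes this bug easy to miss in logs: it simply arrives too late to influence the connection.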
5. Shard Key Resolution
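To see why the resolver mixes the bits before taking the modulus, the following standalone sketch compares plain key % 4 with fmix64(key) % 4 on synthetic Snowflake-style IDs. The bit layout (41-bit millisecond timestamp, 10-bit machine ID, 12-bit sequence) and the "one ID per millisecond" traffic pattern are assumptions chosen to illustrate the low-traffic case; ShardDistributionDemo is a hypothetical name.

```java
import java.util.Arrays;

// Compares plain key % 4 with fmix64(key) % 4 on synthetic Snowflake-style IDs.
public class ShardDistributionDemo {
    static final int NUM_SHARDS = 4;

    // MurmurHash3 64-bit finalizer (fmix64), the same mixing steps as ShardKeyResolver.resolve(long).
    static long fmix64(long h) {
        h ^= (h >>> 33);
        h *= 0xff51afd7ed558ccdL;
        h ^= (h >>> 33);
        h *= 0xc4ceb9fe1a85ec53L;
        h ^= (h >>> 33);
        return h;
    }

    static int[] countPerShard(long[] ids, boolean mix) {
        int[] counts = new int[NUM_SHARDS];
        for (long id : ids) {
            long h = mix ? fmix64(id) : id;
            counts[Math.floorMod(h, NUM_SHARDS)]++; // floorMod never returns a negative index
        }
        return counts;
    }

    // Assumed layout: 41-bit ms timestamp | 10-bit machine ID | 12-bit sequence.
    // One ID per millisecond (low traffic), so the sequence bits stay 0.
    static long[] lowTrafficSnowflakeIds(int n, long machineId) {
        long[] ids = new long[n];
        for (int i = 0; i < n; i++) {
            long timestamp = 1_000_000L + i;
            ids[i] = (timestamp << 22) | (machineId << 12); // sequence = 0
        }
        return ids;
    }

    public static void main(String[] args) {
        long[] ids = lowTrafficSnowflakeIds(10_000, 7);
        System.out.println("key % 4:         " + Arrays.toString(countPerShard(ids, false))); // [10000, 0, 0, 0]
        System.out.println("fmix64(key) % 4: " + Arrays.toString(countPerShard(ids, true)));
    }
}
```

key % 4 only sees the two lowest bits, which here belong to the always-zero sequence, so every ID lands on shard 0; after fmix64 the timestamp bits influence the low bits and the counts come out close to 2,500 per shard.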
ShardKeyResolver.java
@Component
public class ShardKeyResolver {
private final int numShards;
public ShardKeyResolver(@Value("${app.sharding.num-shards}") int numShards) {
this.numShards = numShards;
}
/**
* Resolve a long key to a shard number.
* Applies the MurmurHash3 64-bit finalizer (fmix64) before the modulo,
* which distributes far better than Java's hashCode or the raw key.
*
* Why not just key % numShards?
* Sequential user IDs (1, 2, 3, ...) would distribute evenly with simple modulo.
* Snowflake IDs do not: key % 4 looks only at the lowest bits, which belong to
* the per-millisecond sequence counter. The sequence restarts at 0 every
* millisecond, so under low traffic most IDs end in 0 and land on shard 0.
* fmix64 lets every input bit, including the timestamp, influence the low bits.
*/
public int resolve(long key) {
long h = key;
h ^= (h >>> 33);
h *= 0xff51afd7ed558ccdL;
h ^= (h >>> 33);
h *= 0xc4ceb9fe1a85ec53L;
h ^= (h >>> 33);
// floorMod, not Math.abs(h) % numShards: Math.abs(Long.MIN_VALUE) is still
// negative, which would produce a negative shard ID.
return (int) Math.floorMod(h, (long) numShards);
}
/**
* Resolve a String key (tenant ID, correlation ID).
*/
public int resolve(String key) {
if (key == null || key.isBlank()) {
throw new IllegalArgumentException("Shard key cannot be null or blank");
}
// String.hashCode() is specified by the String Javadoc, so it is stable across
// JVMs; routing it through resolve(long) applies the same bit mixing.
return resolve((long) key.hashCode());
}
public int getNumShards() {
return numShards;
}
}
6. Custom Annotations and AOP: Automatic Routing
Without AOP, every service method would need boilerplate like:
// WITHOUT AOP: verbose, error-prone, easy to forget the finally
public Optional<User> getUserById(Long userId) {
ShardContext.set(shardKeyResolver.resolve(userId));
try {
return userRepository.findById(userId);
} finally {
ShardContext.clear(); // Easy to forget this!
}
}
With AOP, you write:
// WITH AOP: clean, no boilerplate, impossible to forget the cleanup
@ShardedOperation(shardKeyParam = "userId")
public Optional<User> getUserById(Long userId) {
return userRepository.findById(userId); // AOP handles everything else
}
ShardedOperation.java - Method-Level Annotation
/**
* Marks a service method as a sharded operation.
* The AOP aspect will extract the specified parameter's value,
* resolve the shard number, and set ShardContext before the method runs.
*
* Usage:
* @ShardedOperation(shardKeyParam = "userId")
* public User getUserById(Long userId) { ... }
*
* @ShardedOperation(shardKeyParam = "customerId")
* public List<Order> getOrdersForCustomer(Long customerId, int page) { ... }
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface ShardedOperation {
/**
* The name of the method parameter that contains the shard key.
* Must match the actual parameter name (requires -parameters compiler flag
* or use of @ShardKey parameter annotation instead).
*/
String shardKeyParam();
}
ShardKey.java - Parameter-Level Annotation (Alternative)
/**
* Annotate the specific method parameter that is the shard key.
* This is the alternative to @ShardedOperation when you prefer
* marking the parameter rather than naming it in the method annotation.
*
* Usage:
* public User getUserById(@ShardKey Long userId) { ... }
* public Order createOrder(@ShardKey Long userId, OrderRequest request) { ... }
*/
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface ShardKey {
}
ShardingAspect.java - The AOP Interceptor
/**
* AOP Aspect that intercepts all @ShardedOperation and @ShardKey annotated methods.
* Automatically resolves the shard, sets ShardContext, runs the method, clears context.
*
* @Order(1): This MUST run before @Transactional (which runs at Integer.MAX_VALUE).
* If @Transactional runs first, it opens a DB connection BEFORE we set ShardContext,
* and the routing DataSource falls back to shard 0 for ALL transactional methods.
*/
@Aspect
@Component
@Order(1)
@Slf4j
public class ShardingAspect {
private final ShardKeyResolver shardKeyResolver;
public ShardingAspect(ShardKeyResolver shardKeyResolver) {
this.shardKeyResolver = shardKeyResolver;
}
/**
* Intercept methods annotated with @ShardedOperation.
* Extracts the named parameter, resolves its shard, sets context.
*/
@Around("@annotation(shardedOperation)")
public Object aroundShardedOperation(
ProceedingJoinPoint joinPoint,
ShardedOperation shardedOperation) throws Throwable {
String paramName = shardedOperation.shardKeyParam();
Object shardKeyValue = extractByParameterName(joinPoint, paramName);
if (shardKeyValue == null) {
throw new IllegalArgumentException(
"Shard key parameter '" + paramName + "' is null in method: " +
joinPoint.getSignature().toShortString()
);
}
int shardId = resolveShardId(shardKeyValue);
log.debug("@ShardedOperation resolved: method={}, shardKey={}, shard={}",
joinPoint.getSignature().getName(), shardKeyValue, shardId);
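// Remember any outer context and restore it on exit, like ShardContextHolder.withShard().
// A plain clear() here would wipe the outer shard when sharded calls are nested.
Integer previous = ShardContext.get();
ShardContext.set(shardId);
try {
return joinPoint.proceed();
} finally {
if (previous != null) {
ShardContext.set(previous);
} else {
ShardContext.clear();
}
}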
}
/**
* Intercept methods that have a parameter annotated with @ShardKey.
* Scans parameter annotations to find which argument is the shard key.
*/
@Around("execution(* com.example.sharding.service..*(.., @com.example.sharding.annotation.ShardKey (*), ..))")
public Object aroundShardKeyParameter(ProceedingJoinPoint joinPoint) throws Throwable {
Object shardKeyValue = extractByParameterAnnotation(joinPoint);
if (shardKeyValue == null) {
// The pointcut guarantees a @ShardKey parameter exists, so null here means the
// caller passed a null key. Fail fast instead of silently routing to shard 0.
throw new IllegalArgumentException(
"@ShardKey parameter is null in method: " +
joinPoint.getSignature().toShortString()
);
}
int shardId = resolveShardId(shardKeyValue);
// Restore any outer context on exit so nested sharded calls keep the outer shard.
Integer previous = ShardContext.get();
ShardContext.set(shardId);
try {
return joinPoint.proceed();
} finally {
if (previous != null) {
ShardContext.set(previous);
} else {
ShardContext.clear();
}
}
}
// --- Private helpers ---
private Object extractByParameterName(ProceedingJoinPoint joinPoint, String paramName) {
MethodSignature sig = (MethodSignature) joinPoint.getSignature();
String[] names = sig.getParameterNames(); // Requires -parameters compiler flag
Object[] args = joinPoint.getArgs();
if (names == null) {
throw new IllegalStateException(
"Parameter names not available. Add '-parameters' to javac compiler args " +
"or use @ShardKey annotation on the parameter instead of @ShardedOperation."
);
}
for (int i = 0; i < names.length; i++) {
if (names[i].equals(paramName)) {
return args[i];
}
}
throw new IllegalArgumentException(
"Parameter '" + paramName + "' not found in method " +
sig.getName() + ". Available params: " + Arrays.toString(names)
);
}
private Object extractByParameterAnnotation(ProceedingJoinPoint joinPoint) {
MethodSignature sig = (MethodSignature) joinPoint.getSignature();
Method method = sig.getMethod();
Annotation[][] paramAnnotations = method.getParameterAnnotations();
Object[] args = joinPoint.getArgs();
for (int i = 0; i < paramAnnotations.length; i++) {
for (Annotation ann : paramAnnotations[i]) {
if (ann instanceof ShardKey) {
return args[i];
}
}
}
return null;
}
private int resolveShardId(Object value) {
// instanceof patterns (Java 16+) instead of a pattern-matching switch, which is
// final only from Java 21; Spring Boot 3.2 also supports Java 17.
if (value instanceof Long l) return shardKeyResolver.resolve(l);
if (value instanceof Integer i) return shardKeyResolver.resolve((long) i);
if (value instanceof String s) return shardKeyResolver.resolve(s);
throw new IllegalArgumentException(
"Unsupported shard key type: " + value.getClass().getName() +
". Supported types: Long, Integer, String"
);
}
}
Enable Parameter Name Discovery (Required for @ShardedOperation)
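Whether parameter names survived compilation can be checked with java.lang.reflect.Parameter.isNamePresent(). The standalone sketch below contrasts that with the @ShardKey lookup, which needs no compiler flag because RUNTIME-retained parameter annotations are always recorded in the class file. The nested @ShardKey is a local copy of the article's annotation, and OrderServiceStub/createOrder are hypothetical names used only for the demo.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;

public class ShardKeyReflectionDemo {

    // Local copy of the article's @ShardKey so the sketch is self-contained.
    @Target(ElementType.PARAMETER)
    @Retention(RetentionPolicy.RUNTIME)
    @interface ShardKey {}

    // Hypothetical service method, used only for this demo.
    static class OrderServiceStub {
        public String createOrder(String note, @ShardKey Long userId) {
            return note;
        }
    }

    static Method createOrderMethod() {
        try {
            return OrderServiceStub.class.getMethod("createOrder", String.class, Long.class);
        } catch (NoSuchMethodException e) {
            throw new IllegalStateException(e);
        }
    }

    // Same idea as ShardingAspect.extractByParameterAnnotation: scan parameter annotations.
    static Object findShardKey(Method method, Object[] args) {
        Parameter[] params = method.getParameters();
        for (int i = 0; i < params.length; i++) {
            if (params[i].isAnnotationPresent(ShardKey.class)) {
                return args[i];
            }
        }
        return null;
    }

    public static void main(String[] args) {
        Method m = createOrderMethod();
        // Works with or without -parameters: RUNTIME parameter annotations are always kept.
        System.out.println("@ShardKey value: " + findShardKey(m, new Object[] {"gift wrap", 42L})); // 42
        // Real names (note, userId) appear only when compiled with -parameters;
        // otherwise reflection reports synthesized names such as arg0, arg1.
        for (Parameter p : m.getParameters()) {
            System.out.println(p.getName() + " (name present: " + p.isNamePresent() + ")");
        }
    }
}
```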
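spring-boot-starter-parent already passes -parameters to javac, and the Spring Boot Gradle plugin does the same for JavaCompile tasks. If you use neither, add this to pom.xml: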
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<compilerArgs>
<arg>-parameters</arg> <!-- Makes parameter names available at runtime -->
</compilerArgs>
</configuration>
</plugin>
</plugins>
</build>
Or in build.gradle:
tasks.withType(JavaCompile) {
options.compilerArgs << '-parameters'
}
7. Entity Layer
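Entities are completely standard JPA, with no sharding-specific code.
The only sharding-relevant decision is the ID type: use an application-assigned Long (Snowflake), not @GeneratedValue. A per-shard sequence or identity column counts 1, 2, 3, ... independently on every shard, so the same ID would exist on several shards. One side effect of assigning IDs before save(): Spring Data JPA then treats the entity as already existing and calls merge() (a SELECT before the INSERT) unless the entity has a @Version field or implements Persistable.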
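The article's SnowflakeIdGenerator is not shown in this part. As a rough idea of what such a generator does, here is a minimal sketch assuming the widely used Twitter-style layout (41-bit millisecond timestamp, 10-bit machine ID, 12-bit sequence) and a custom epoch of 2024-01-01; SnowflakeSketch, CUSTOM_EPOCH and the bit widths are assumptions, not the article's code. The machine ID would presumably come from app.snowflake.machine-id (POD_INDEX), which must then stay within 0..1023.

```java
// Minimal Snowflake-style ID generator (hypothetical; the article's
// SnowflakeIdGenerator may differ). Layout, high to low bits:
// 1 unused sign bit | 41-bit ms since CUSTOM_EPOCH | 10-bit machine ID | 12-bit sequence
public class SnowflakeSketch {
    private static final long CUSTOM_EPOCH = 1_704_067_200_000L; // 2024-01-01T00:00:00Z (assumed)
    private static final int MACHINE_BITS = 10;
    private static final int SEQUENCE_BITS = 12;
    static final long MAX_MACHINE_ID = (1L << MACHINE_BITS) - 1; // 1023
    private static final long SEQUENCE_MASK = (1L << SEQUENCE_BITS) - 1; // 4095

    private final long machineId;
    private long lastTimestamp = -1L;
    private long sequence = 0L;

    public SnowflakeSketch(long machineId) {
        if (machineId < 0 || machineId > MAX_MACHINE_ID) {
            throw new IllegalArgumentException("machineId must be 0.." + MAX_MACHINE_ID);
        }
        this.machineId = machineId;
    }

    public synchronized long nextId() {
        long now = System.currentTimeMillis();
        if (now < lastTimestamp) {
            throw new IllegalStateException("Clock moved backwards by " + (lastTimestamp - now) + " ms");
        }
        if (now == lastTimestamp) {
            sequence = (sequence + 1) & SEQUENCE_MASK;
            if (sequence == 0) { // 4096 IDs already issued this ms: wait for the next ms
                while (now <= lastTimestamp) {
                    now = System.currentTimeMillis();
                }
            }
        } else {
            sequence = 0;
        }
        lastTimestamp = now;
        return ((now - CUSTOM_EPOCH) << (MACHINE_BITS + SEQUENCE_BITS))
                | (machineId << SEQUENCE_BITS)
                | sequence;
    }

    public static void main(String[] args) {
        SnowflakeSketch gen = new SnowflakeSketch(3);
        long a = gen.nextId();
        long b = gen.nextId();
        System.out.println("increasing: " + (a < b));                                     // true
        System.out.println("machine bits: " + ((b >>> SEQUENCE_BITS) & MAX_MACHINE_ID)); // 3
    }
}
```

Because each pod uses a distinct machine ID, two pods can never produce the same value even when they generate IDs in the same millisecond, which is what makes these IDs safe across shards.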
User.java
@Entity
@Table(
name = "users",
indexes = {
@Index(name = "idx_users_email", columnList = "email"),
@Index(name = "idx_users_status", columnList = "status"),
@Index(name = "idx_users_created_at", columnList = "created_at")
}
)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class User {
/**
* Snowflake ID. Assigned BEFORE persistence by SnowflakeIdGenerator.
* No @GeneratedValue: we provide the ID; the database does NOT generate it.
* This is MANDATORY for sharding - per-shard auto-increment would produce the
* same IDs on every shard, so IDs would collide across shards.
*/
@Id
private Long id;
@Column(nullable = false, unique = true, length = 255)
private String email;
@Column(name = "full_name", nullable = false, length = 255)
private String fullName;
@Column(name = "phone_number", length = 20)
private String phoneNumber;
@Enumerated(EnumType.STRING)
@Column(nullable = false, length = 20)
private UserStatus status;
@Column(name = "created_at", nullable = false, updatable = false)
private Instant createdAt;
@Column(name = "updated_at")
private Instant updatedAt;
@PrePersist
void prePersist() {
if (createdAt == null) createdAt = Instant.now();
updatedAt = Instant.now();
}
@PreUpdate
void preUpdate() {
updatedAt = Instant.now();
}
}
Order.java
@Entity
@Table(
name = "orders",
indexes = {
@Index(name = "idx_orders_user_id", columnList = "user_id"),
@Index(name = "idx_orders_status", columnList = "status"),
@Index(name = "idx_orders_created_at", columnList = "created_at")
}
)
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class Order {
@Id
private Long id; // Snowflake ID
/**
* user_id is the shard key for both User and Order tables.
* By sharding orders by user_id, all of a user's orders live on the same shard
* as the user themselves. The query "give me all orders for user X" = single shard.
*/
@Column(name = "user_id", nullable = false)
private Long userId;
@Column(name = "total_amount", nullable = false, precision = 12, scale = 2)
private BigDecimal totalAmount;
@Enumerated(EnumType.STRING)
@Column(nullable = false, length = 30)
private OrderStatus status;
@Column(name = "delivery_address", length = 500)
private String deliveryAddress;
@Column(name = "created_at", nullable = false, updatable = false)
private Instant createdAt;
@PrePersist
void prePersist() {
if (createdAt == null) createdAt = Instant.now();
}
}
Flyway Migration: V1__create_tables.sql
This migration runs on EVERY shard via ShardedFlywayMigrationRunner (see Part 2).
-- Each shard has identical schema. Data differs, schema is same.
CREATE TABLE IF NOT EXISTS users (
id BIGINT PRIMARY KEY, -- Snowflake ID, no sequence
email VARCHAR(255) NOT NULL UNIQUE,
full_name VARCHAR(255) NOT NULL,
phone_number VARCHAR(20),
status VARCHAR(20) NOT NULL DEFAULT 'ACTIVE',
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
-- No separate index on email: the UNIQUE constraint above already creates one.
CREATE INDEX idx_users_status ON users(status);
CREATE INDEX idx_users_created_at ON users(created_at);
CREATE TABLE IF NOT EXISTS orders (
id BIGINT PRIMARY KEY,
user_id BIGINT NOT NULL,
total_amount NUMERIC(12, 2) NOT NULL,
status VARCHAR(30) NOT NULL DEFAULT 'PENDING',
delivery_address VARCHAR(500),
created_at TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE INDEX idx_orders_user_id ON orders(user_id);
CREATE INDEX idx_orders_status ON orders(status);
CREATE INDEX idx_orders_created_at ON orders(created_at);
-- Outbox table on every shard for reliable event publishing
CREATE TABLE IF NOT EXISTS outbox_events (
id BIGINT PRIMARY KEY,
aggregate_id VARCHAR(100) NOT NULL,
aggregate_type VARCHAR(100) NOT NULL,
event_type VARCHAR(100) NOT NULL,
payload TEXT NOT NULL,
status VARCHAR(20) NOT NULL DEFAULT 'PENDING',
retry_count INT NOT NULL DEFAULT 0,
last_error VARCHAR(500),
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
published_at TIMESTAMPTZ
);
CREATE INDEX idx_outbox_status_created ON outbox_events(status, created_at);
8. Repository Layer
Repositories are 100% standard Spring Data JPA. No sharding code here.
The routing happens transparently via ShardRoutingDataSource.
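Because every repository query runs on exactly one shard, aggregate queries such as countActiveUsers() or countPendingOrders() return per-shard numbers, and a global figure has to be assembled from a scatter-gather. The sketch below shows one way to combine the list returned by ShardContextHolder.scatterGatherParallel(), where a failed or timed-out shard shows up as null: it sums what arrived and reports which shards are missing instead of silently treating them as zero. ScatterGatherTotals, Total and sumCounts are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Combines per-shard results from a scatter-gather call, where a null entry
// means that shard failed or timed out (the contract of scatterGatherParallel).
public class ScatterGatherTotals {

    record Total(long sum, List<Integer> missingShards) {
        boolean isComplete() {
            return missingShards.isEmpty();
        }
    }

    static Total sumCounts(List<Long> perShard) {
        long sum = 0;
        List<Integer> missing = new ArrayList<>();
        for (int shardId = 0; shardId < perShard.size(); shardId++) {
            Long count = perShard.get(shardId);
            if (count == null) {
                missing.add(shardId); // report it; do not silently treat it as 0
            } else {
                sum += count;
            }
        }
        return new Total(sum, List.copyOf(missing));
    }

    public static void main(String[] args) {
        // e.g. countPendingOrders() on shards 0..3, with shard 1 timing out
        Total t = sumCounts(Arrays.asList(120L, null, 95L, 101L));
        System.out.println(t.sum() + " pending, complete=" + t.isComplete()
                + ", missing=" + t.missingShards()); // 316 pending, complete=false, missing=[1]
    }
}
```

Whether a partial total is acceptable (a dashboard) or must be rejected (billing) is a caller decision; returning the missing shards keeps that decision explicit.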
UserRepository.java
/**
* Standard Spring Data JPA repository.
* Works on whichever shard is currently in ShardContext when called.
*
* Do NOT call this from a controller directly.
* Always go through UserService which sets ShardContext via @ShardedOperation.
*/
@Repository
public interface UserRepository extends JpaRepository<User, Long> {
Optional<User> findByEmail(String email);
List<User> findByStatus(UserStatus status);
@Query("SELECT u FROM User u WHERE u.createdAt BETWEEN :from AND :to ORDER BY u.createdAt")
List<User> findUsersCreatedBetween(
@Param("from") Instant from,
@Param("to") Instant to
);
@Query("SELECT COUNT(u) FROM User u WHERE u.status = 'ACTIVE'")
long countActiveUsers();
// Paginated query: users with the given status on this shard, newest first
@Query("SELECT u FROM User u WHERE u.status = :status ORDER BY u.createdAt DESC")
Page<User> findByStatusPaged(@Param("status") UserStatus status, Pageable pageable);
}
OrderRepository.java
@Repository
public interface OrderRepository extends JpaRepository<Order, Long> {
List<Order> findByUserId(Long userId);
Page<Order> findByUserId(Long userId, Pageable pageable);
Optional<Order> findByIdAndUserId(Long id, Long userId);
List<Order> findByUserIdAndStatus(Long userId, OrderStatus status);
@Query("SELECT COUNT(o) FROM Order o WHERE o.status = 'PENDING'")
long countPendingOrders(); // Counts the current shard only; a global total needs scatter-gather
@Query("SELECT SUM(o.totalAmount) FROM Order o WHERE o.userId = :userId")
Optional<BigDecimal> sumOrderAmountByUser(@Param("userId") Long userId);
}
EmailIndexRepository.java - Non-Sharded Secondary Index
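/**
* Secondary index database: maps email -> user_id.
* This repository uses a SEPARATE, non-sharded DataSource named "indexDataSource".
* It does NOT use the routingDataSource. That requires its own EntityManagerFactory
* and transaction manager for indexDataSource, plus @EnableJpaRepositories filters
* so the sharded configuration does not also pick up this repository.
*
* Purpose: find a user's shard from an email with one indexed lookup instead of
* a scatter-gather. Without this, "find user by email" must query all shards.
*/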
@Repository
public interface EmailIndexRepository extends JpaRepository<EmailIndex, String> {
Optional<EmailIndex> findByEmail(String email);
void deleteByUserId(Long userId);
}
@Entity
@Table(name = "email_user_index")
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class EmailIndex {
@Id
@Column(length = 255)
private String email;
@Column(name = "user_id", nullable = false)
private Long userId;
// Optional optimization: cache the shard ID to skip resolver computation
@Column(name = "shard_id", nullable = false)
private int shardId;
@Column(name = "created_at", nullable = false)
private Instant createdAt;
}
9. Service Layer: Putting It All Together
UserService.java
@Service
@Slf4j
@RequiredArgsConstructor
public class UserService {
private final UserRepository userRepository;
private final OrderRepository orderRepository;
private final OutboxRepository outboxRepository;
private final EmailIndexRepository emailIndexRepository;
private final ShardContextHolder shardContextHolder;
private final ShardKeyResolver shardKeyResolver;
private final SnowflakeIdGenerator idGenerator;
private final ObjectMapper objectMapper;
// ============================================================
// READ OPERATIONS
// @ShardedOperation routes to correct shard before DB call
// @Transactional(readOnly=true) opens connection on correct shard
// ============================================================
@ShardedOperation(shardKeyParam = "userId")
@Transactional(readOnly = true)
public Optional<User> getUserById(Long userId) {
log.debug("Getting user {} from shard {}", userId, ShardContext.get());
return userRepository.findById(userId);
}
@ShardedOperation(shardKeyParam = "userId")
@Transactional(readOnly = true)
public Page<Order> getOrdersForUser(Long userId, Pageable pageable) {
return orderRepository.findByUserId(userId, pageable);
}
// ============================================================
// LOOKUP BY EMAIL - Uses secondary index (no scatter-gather)
// ============================================================
public Optional<User> getUserByEmail(String email) {
// Step 1: Non-sharded lookup: email -> user_id (single fast query)
Optional<EmailIndex> emailIndex = emailIndexRepository
.findByEmail(email.toLowerCase().trim());
if (emailIndex.isEmpty()) {
return Optional.empty();
}
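// Step 2: Use user_id to route to the correct shard.
// Calling this.getUserById(...) would bypass the Spring proxy (self-invocation),
// so neither @ShardedOperation nor @Transactional would run. Route explicitly instead.
Long userId = emailIndex.get().getUserId();
return shardContextHolder.withKey(userId, () -> userRepository.findById(userId));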
}
// ============================================================
// CREATE USER
// Writes to: (1) users table, (2) outbox table - both on the user's shard,
// in one local transaction
// Outbox relay later updates email index DB asynchronously
// ============================================================
// Assumed bean: a TransactionTemplate bound to the transaction manager of the routing DataSource
private final TransactionTemplate transactionTemplate;
// No @Transactional on this method: the transaction interceptor would open the
// connection BEFORE ShardContext.set() runs, i.e. on the default shard.
// Instead: resolve the shard, set ShardContext, THEN start the transaction.
public User createUser(CreateUserRequest request) {
// Generate Snowflake ID BEFORE inserting - no DB sequence needed
Long userId = idGenerator.nextId();
int shardId = shardKeyResolver.resolve(userId);
// Normalize once so the user row and the email-index event agree
String email = request.getEmail().toLowerCase().trim();
// Serialize before any connection is opened; a failure here touches no database
String payload;
try {
payload = objectMapper.writeValueAsString(
new UserCreatedPayload(userId, email, shardId)
);
} catch (JsonProcessingException e) {
throw new RuntimeException("Failed to serialize outbox payload", e);
}
ShardContext.set(shardId);
try {
// The transaction (and its connection) starts here, after ShardContext is set
return transactionTemplate.execute(status -> {
User user = User.builder()
.id(userId)
.email(email)
.fullName(request.getFullName())
.phoneNumber(request.getPhoneNumber())
.status(UserStatus.ACTIVE)
.build();
User saved = userRepository.save(user);
// Outbox row in the SAME transaction: the user row and the outbox row
// commit atomically on this shard. The outbox relay publishes the Kafka event.
outboxRepository.save(OutboxEvent.builder()
.id(idGenerator.nextId())
.aggregateId(String.valueOf(userId))
.aggregateType("User")
.eventType("UserCreated")
.payload(payload)
.status(OutboxStatus.PENDING)
.build());
log.info("User {} created on shard {}", userId, shardId);
return saved;
});
} finally {
ShardContext.clear();
}
}
// ============================================================
// UPDATE USER - Single shard via @ShardedOperation
// ============================================================
@ShardedOperation(shardKeyParam = "userId")
@Transactional
public User updateUserProfile(Long userId, UpdateProfileRequest request) {
User user = userRepository.findById(userId)
.orElseThrow(() -> new UserNotFoundException(userId));
if (request.getFullName() != null) {
user.setFullName(request.getFullName());
}
if (request.getPhoneNumber() != null) {
user.setPhoneNumber(request.getPhoneNumber());
}
// @Transactional commits on Shard X (determined by @ShardedOperation before this method ran)
return userRepository.save(user);
}
// ============================================================
// DELETE USER - Must clean up across shard DB AND index DB
// ============================================================
@ShardedOperation(shardKeyParam = "userId")
@Transactional
public void deleteUser(Long userId) {
User user = userRepository.findById(userId)
.orElseThrow(() -> new UserNotFoundException(userId));
// Soft delete on the shard
user.setStatus(UserStatus.DELETED);
userRepository.save(user);
// Write outbox event to clean up email index (async)
String payload = "{\"userId\": " + userId + "}";
outboxRepository.save(OutboxEvent.builder()
.id(idGenerator.nextId())
.aggregateId(String.valueOf(userId))
.aggregateType("User")
.eventType("UserDeleted")
.payload(payload)
.status(OutboxStatus.PENDING)
.build());
// Both soft delete and outbox event committed atomically
}
}
OrderService.java
@Service
@Slf4j
@RequiredArgsConstructor
public class OrderService {
private final OrderRepository orderRepository;
private final UserRepository userRepository;
private final OutboxRepository outboxRepository;
private final ShardContextHolder shardContextHolder; // used by getOrderById
private final ShardKeyResolver shardKeyResolver;
private final SnowflakeIdGenerator idGenerator;
private final ObjectMapper objectMapper;
/**
* Place an order.
*
* @Transactional(on shard X) covers:
* - Fetch user to validate (shard X)
* - Save order (shard X)
* - Save outbox event (shard X)
* All three in one atomic DB transaction on shard X.
*
* The outbox relay handles downstream: inventory, payment, notification.
*/
@ShardedOperation(shardKeyParam = "userId")
@Transactional
public Order placeOrder(Long userId, PlaceOrderRequest request) {
// Validate user on same shard
User user = userRepository.findById(userId)
.orElseThrow(() -> new UserNotFoundException(userId));
if (user.getStatus() != UserStatus.ACTIVE) {
throw new UserNotEligibleException(userId, user.getStatus());
}
Long orderId = idGenerator.nextId();
Order order = Order.builder()
.id(orderId)
.userId(userId)
.totalAmount(request.getTotalAmount())
.status(OrderStatus.PENDING)
.deliveryAddress(request.getDeliveryAddress())
.build();
Order saved = orderRepository.save(order);
// Outbox event triggers the downstream Saga (inventory, payment, notification)
try {
String payload = objectMapper.writeValueAsString(
new OrderCreatedPayload(orderId, userId, request)
);
outboxRepository.save(OutboxEvent.builder()
.id(idGenerator.nextId())
.aggregateId(String.valueOf(orderId))
.aggregateType("Order")
.eventType("OrderCreated")
.payload(payload)
.status(OutboxStatus.PENDING)
.build());
} catch (JsonProcessingException e) {
throw new RuntimeException("Failed to serialize order event", e);
}
log.info("Order {} placed for user {} on shard {}",
orderId, userId, ShardContext.get());
return saved;
}
/**
* Retrieve order by orderId.
* We need userId to route to the correct shard (orders are co-located with users).
*
* Alternative design: embed shard ID in the Snowflake order ID (see Part 2).
*/
public Optional<Order> getOrderById(Long orderId, Long userId) {
return shardContextHolder.withKey(userId,
() -> orderRepository.findByIdAndUserId(orderId, userId)
);
}
@ShardedOperation(shardKeyParam = "userId")
@Transactional
public Order cancelOrder(Long userId, Long orderId) {
Order order = orderRepository.findByIdAndUserId(orderId, userId)
.orElseThrow(() -> new OrderNotFoundException(orderId));
if (order.getStatus() == OrderStatus.DELIVERED) {
throw new OrderCancellationNotAllowedException(orderId);
}
order.setStatus(OrderStatus.CANCELLED);
Order saved = orderRepository.save(order);
// Trigger compensation saga: refund payment, release inventory
outboxRepository.save(OutboxEvent.builder()
.id(idGenerator.nextId())
.aggregateId(String.valueOf(orderId))
.aggregateType("Order")
.eventType("OrderCancelled")
.payload("{\"orderId\": " + orderId + ", \"userId\": " + userId + "}")
.status(OutboxStatus.PENDING)
.build());
return saved;
}
}
10. REST Controller Layer
@RestController
@RequestMapping("/api/users")
@Slf4j
@RequiredArgsConstructor
public class UserController {
private final UserService userService;
@PostMapping
@ResponseStatus(HttpStatus.CREATED)
public UserResponse createUser(@RequestBody @Valid CreateUserRequest request) {
User user = userService.createUser(request);
return toResponse(user);
}
/**
* userId in the path IS the shard key.
* UserService.getUserById is @ShardedOperation(shardKeyParam = "userId")
* so routing happens transparently.
*/
@GetMapping("/{userId}")
public ResponseEntity<UserResponse> getUserById(@PathVariable Long userId) {
return userService.getUserById(userId)
.map(u -> ResponseEntity.ok(toResponse(u)))
.orElse(ResponseEntity.notFound().build());
}
/**
* Email lookup: service routes via secondary index (no scatter-gather).
*/
@GetMapping("/lookup")
public ResponseEntity<UserResponse> getUserByEmail(@RequestParam String email) {
return userService.getUserByEmail(email)
.map(u -> ResponseEntity.ok(toResponse(u)))
.orElse(ResponseEntity.notFound().build());
}
@PutMapping("/{userId}")
public ResponseEntity<UserResponse> updateUser(
@PathVariable Long userId,
@RequestBody @Valid UpdateProfileRequest request) {
User user = userService.updateUserProfile(userId, request);
return ResponseEntity.ok(toResponse(user));
}
@DeleteMapping("/{userId}")
@ResponseStatus(HttpStatus.NO_CONTENT)
public void deleteUser(@PathVariable Long userId) {
userService.deleteUser(userId);
}
private UserResponse toResponse(User user) {
return UserResponse.builder()
.id(user.getId())
.email(user.getEmail())
.fullName(user.getFullName())
.status(user.getStatus())
.createdAt(user.getCreatedAt())
.build();
}
}
11. Common Mistakes and How Spring Boot Exposes Them
Mistake 1: Calling Repository Without Setting ShardContext
// WRONG - No ShardContext is set, so AbstractRoutingDataSource falls back to its
// default target DataSource (shard 0 in this configuration).
// user_id 5042 is on shard 2. This query hits shard 0 and returns empty.
// The bug is silent: you get Optional.empty() instead of the user.
@GetMapping("/{userId}")
public ResponseEntity<User> getUser(@PathVariable Long userId) {
return userRepository.findById(userId) // Always hits shard 0
.map(ResponseEntity::ok)
.orElse(ResponseEntity.notFound().build()); // 404 for every user not stored on shard 0
}
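The silent 404 follows from how the fallback works: when the current lookup key is null, AbstractRoutingDataSource returns its default target DataSource instead of failing. The stdlib model below reproduces that with a ThreadLocal and one map per shard; the names are hypothetical and modulo-4 placement is assumed. Its strict flag sketches one way to make the mistake loud instead of silent: have your determineCurrentLookupKey() override throw when no shard is set. Framework code can also request connections outside any request (for example at startup), so a strict check like this needs an allowance for those paths.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

// Stdlib model of the routing fallback (hypothetical names; not Spring's AbstractRoutingDataSource).
class RoutingFallbackModel {
    static final ThreadLocal<Integer> CURRENT_SHARD = new ThreadLocal<>();
    static final int DEFAULT_SHARD = 0;
    // shardId -> (userId -> name): one map per shard stands in for one DataSource per shard
    static final Map<Integer, Map<Long, String>> SHARDS = new HashMap<>();

    static {
        for (int i = 0; i < 4; i++) {
            SHARDS.put(i, new HashMap<>());
        }
        SHARDS.get(2).put(5042L, "alice"); // assumption: modulo-4 placement, 5042 % 4 == 2
    }

    // Lenient mode mirrors the fallback: a null lookup key resolves to the default target.
    // Strict mode is the fail-fast alternative: no ShardContext means an exception.
    static int resolveTarget(boolean strict) {
        Integer key = CURRENT_SHARD.get();
        if (key == null) {
            if (strict) {
                throw new IllegalStateException("No ShardContext set on this thread");
            }
            return DEFAULT_SHARD;
        }
        return key;
    }

    static Optional<String> findById(long userId, boolean strict) {
        return Optional.ofNullable(SHARDS.get(resolveTarget(strict)).get(userId));
    }

    public static void main(String[] args) {
        System.out.println(findById(5042L, false)); // no context: reads shard 0, empty
        CURRENT_SHARD.set(2);
        try {
            System.out.println(findById(5042L, false)); // routed to shard 2: found
        } finally {
            CURRENT_SHARD.remove();
        }
    }
}
```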
// CORRECT - Go through the service which has @ShardedOperation
@GetMapping("/{userId}")
public ResponseEntity<User> getUser(@PathVariable Long userId) {
return userService.getUserById(userId) // @ShardedOperation routes correctly
.map(ResponseEntity::ok)
.orElse(ResponseEntity.notFound().build());
}
Mistake 2: Wrong @Order on the Aspect - @Transactional Runs First
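// WRONG - ShardingAspect keeps the default order (Ordered.LOWEST_PRECEDENCE = Integer.MAX_VALUE),
// the same as the transaction advisor. With equal order values Spring does not guarantee
// which advice runs first, so @Transactional can open the connection BEFORE ShardContext is set.
// That connection comes from the default shard (shard 0), and all writes go to shard 0.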
@Aspect
@Component // Missing @Order(1) !!
public class ShardingAspect {
@Around("@annotation(shardedOperation)")
public Object aroundShardedOperation(...) {
ShardContext.set(shardId); // TOO LATE - @Transactional already opened connection
try {
return joinPoint.proceed();
} finally {
ShardContext.clear();
}
}
}
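Advice ordering can be modeled in plain Java: a lower order value wraps further out, so it runs first, and the transaction advice reads the shard exactly once, when it begins. In this sketch the names are hypothetical, shard 0 stands in for the default target, and Integer.MAX_VALUE plays the role of Spring's default Ordered.LOWEST_PRECEDENCE for the transaction advisor. Running the same repository call under both orderings shows which shard the bound connection ends up on.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.function.Supplier;

// Stdlib model of advice ordering (hypothetical; not Spring's AOP machinery).
class AdviceOrderModel {
    static final ThreadLocal<Integer> CURRENT_SHARD = new ThreadLocal<>();
    static final ThreadLocal<Integer> BOUND_CONNECTION = new ThreadLocal<>();

    interface Advice {
        int order();
        Object invoke(Supplier<Object> next);
    }

    // Sets the shard around the call, like ShardingAspect
    static Advice sharding(int order, int shardId) {
        return new Advice() {
            public int order() { return order; }
            public Object invoke(Supplier<Object> next) {
                CURRENT_SHARD.set(shardId);
                try { return next.get(); } finally { CURRENT_SHARD.remove(); }
            }
        };
    }

    // Binds a connection when the transaction begins; the shard is read ONCE, here
    static Advice transactional(int order) {
        return new Advice() {
            public int order() { return order; }
            public Object invoke(Supplier<Object> next) {
                Integer key = CURRENT_SHARD.get();
                BOUND_CONNECTION.set(key == null ? 0 : key); // null key -> default shard 0
                try { return next.get(); } finally { BOUND_CONNECTION.remove(); }
            }
        };
    }

    // Lower order value = outer advice = runs first
    static Object call(List<Advice> advices, Supplier<Object> target) {
        List<Advice> sorted = new ArrayList<>(advices);
        sorted.sort(Comparator.comparingInt(Advice::order));
        Supplier<Object> chain = target;
        for (int i = sorted.size() - 1; i >= 0; i--) {
            Advice advice = sorted.get(i);
            Supplier<Object> next = chain;
            chain = () -> advice.invoke(next);
        }
        return chain.get();
    }

    public static void main(String[] args) {
        Supplier<Object> repositoryCall = BOUND_CONNECTION::get; // which shard did we hit?
        // Sharding outside the transaction: the connection is bound to shard 2
        System.out.println(call(List.of(sharding(1, 2), transactional(Integer.MAX_VALUE)), repositoryCall));
        // Sharding inside the transaction: the connection was already bound to shard 0
        System.out.println(call(List.of(sharding(Integer.MAX_VALUE, 2), transactional(1)), repositoryCall));
    }
}
```

In Spring the same effect comes from @Order(1) on the aspect, as in the corrected version; the transaction advisor's position can also be set explicitly through the order attribute of @EnableTransactionManagement.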
// CORRECT - @Order(1) ensures ShardingAspect runs BEFORE @Transactional
@Aspect
@Component
@Order(1) // This is non-negotiable
public class ShardingAspect { ... }
Mistake 3: @Transactional Does NOT Protect Cross-Shard Operations
// WRONG - @Transactional binds ONE connection on ONE shard when the transaction begins.
// Switching ShardContext inside the transaction does nothing: every statement reuses
// that connection. Here ShardContext is still empty when the transaction begins, so
// both operations hit the default shard (shard 0), not shard A or shard B.
@Transactional
public void transferMoney(Long fromUserId, Long toUserId, BigDecimal amount) {
ShardContext.set(shardKeyResolver.resolve(fromUserId)); // Too late: connection already bound
accountRepository.debit(fromUserId, amount); // Default shard, not shard A
ShardContext.set(shardKeyResolver.resolve(toUserId)); // NO EFFECT inside transaction
accountRepository.credit(toUserId, amount); // Same default shard (wrong)
}
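The orchestrator's job can be sketched without Spring. Each step below is its own local transaction on one shard, and if the credit fails after the debit has committed, a compensating credit refunds the sender. Maps stand in for account tables on two shards; debit, credit, the injected failure and the class name are hypothetical, and modulo-4 placement is assumed.

```java
import java.math.BigDecimal;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

// Stdlib sketch of an orchestrated transfer saga (hypothetical names; maps stand in for shards).
class TransferSagaSketch {
    // shardId -> (userId -> balance); each map update plays the role of one committed local transaction
    final Map<Integer, Map<Long, BigDecimal>> shards = new HashMap<>();
    Long failCreditFor = null; // injected failure: credits to this user throw

    static int shardFor(long userId) {
        return Math.floorMod(userId, 4); // assumption: modulo-4 placement
    }

    void open(long userId, String balance) {
        shards.computeIfAbsent(shardFor(userId), k -> new HashMap<>()).put(userId, new BigDecimal(balance));
    }

    BigDecimal balance(long userId) {
        return shards.get(shardFor(userId)).get(userId);
    }

    // Local transaction on the debtor's shard only
    void debit(long userId, BigDecimal amount) {
        Map<Long, BigDecimal> shard = shards.get(shardFor(userId));
        BigDecimal current = shard.get(userId);
        if (current.compareTo(amount) < 0) {
            throw new IllegalStateException("Insufficient funds");
        }
        shard.put(userId, current.subtract(amount));
    }

    // Local transaction on the creditor's shard only
    void credit(long userId, BigDecimal amount) {
        if (Objects.equals(failCreditFor, userId)) {
            throw new IllegalStateException("Credit failed on shard " + shardFor(userId));
        }
        Map<Long, BigDecimal> shard = shards.get(shardFor(userId));
        shard.put(userId, shard.get(userId).add(amount));
    }

    // Returns true if both steps committed, false if step 1 was compensated
    boolean executeTransfer(long fromUserId, long toUserId, BigDecimal amount) {
        debit(fromUserId, amount);          // step 1 commits on shard A
        try {
            credit(toUserId, amount);       // step 2 commits on shard B
            return true;
        } catch (IllegalStateException e) {
            credit(fromUserId, amount);     // compensating step: refund on shard A
            return false;
        }
    }

    public static void main(String[] args) {
        TransferSagaSketch saga = new TransferSagaSketch();
        saga.open(5042L, "100.00"); // shard 2
        saga.open(5043L, "0.00");   // shard 3
        boolean done = saga.executeTransfer(5042L, 5043L, new BigDecimal("30.00"));
        System.out.println(done + " " + saga.balance(5042L) + " " + saga.balance(5043L));
    }
}
```

Two properties differ from a database transaction: between step 1 and step 2 other readers can see the money missing from both accounts (there is no isolation), and a production orchestrator has to persist saga state and retry the compensation until it succeeds, for example through outbox events like the ones used in the service layer.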
// CORRECT - Run two separate local transactions, each on its own shard.
// The Saga pattern (with a compensating step on failure) handles the "all or nothing" requirement
public void transferMoney(Long fromUserId, Long toUserId, BigDecimal amount) {
sagaOrchestrator.executeTransfer(fromUserId, toUserId, amount);
}
Mistake 4: @Async Methods Lose ThreadLocal Context
// WRONG - @Async runs on a different thread. ThreadLocal is per-thread.
// The new thread has no ShardContext. Falls back to shard 0.
@Async
public CompletableFuture<Void> sendNotificationAsync(Long userId) {
User user = userRepository.findById(userId).orElseThrow(); // Shard 0! Wrong!
notificationService.send(user.getEmail(), "Welcome!");
return CompletableFuture.completedFuture(null);
}
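The loss is easy to reproduce with a plain executor, and so is the usual fix: capture the value on the submitting thread, re-set it on the worker, and clean up in finally so the pooled thread does not keep it. This is a stdlib model; withCurrentShard is a hypothetical helper. In Spring, a wrapper like this is what you would plug in through a TaskDecorator on a ThreadPoolTaskExecutor.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

// Stdlib model of ThreadLocal propagation across an executor (hypothetical helper names).
class ShardContextPropagation {
    static final ThreadLocal<Integer> CURRENT_SHARD = new ThreadLocal<>();

    // Captures the caller's shard now, re-sets it on the worker thread, restores afterwards
    static <T> Callable<T> withCurrentShard(Callable<T> task) {
        Integer captured = CURRENT_SHARD.get(); // read on the submitting thread
        return () -> {
            Integer previous = CURRENT_SHARD.get();
            if (captured == null) CURRENT_SHARD.remove(); else CURRENT_SHARD.set(captured);
            try {
                return task.call();
            } finally {
                if (previous == null) CURRENT_SHARD.remove(); else CURRENT_SHARD.set(previous);
            }
        };
    }

    public static void main(String[] args) throws Exception {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            CURRENT_SHARD.set(2);
            Future<Integer> lost = pool.submit(() -> CURRENT_SHARD.get());
            Future<Integer> kept = pool.submit(withCurrentShard(() -> CURRENT_SHARD.get()));
            System.out.println(lost.get() + " " + kept.get()); // null 2
        } finally {
            CURRENT_SHARD.remove();
            pool.shutdown();
        }
    }
}
```

InheritableThreadLocal is not a substitute here: pool threads are created once and reused, so they inherit whatever was set when the thread was created, not the shard of the current request.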
// CORRECT - Pass the shard key (userId) across the async boundary and
// re-set ShardContext on the @Async thread before touching the repository.
@Async
public CompletableFuture<Void> sendNotificationAsync(Long userId) {
// Already running on the @Async executor thread; no second runAsync hop needed
shardContextHolder.withKey(userId, () -> {
User user = userRepository.findById(userId).orElseThrow();
notificationService.send(user.getEmail(), "Welcome!");
return null;
});
return CompletableFuture.completedFuture(null);
}
Mistake 5: Not Using remove() to Clear ThreadLocal
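// WRONG - set(null) only overwrites the value: the entry for this ThreadLocal stays in the
// thread's ThreadLocalMap. Pooled request threads live as long as the application, so these
// stale entries are never reclaimed; once the context holds more than an Integer, this is how
// ThreadLocal memory leaks start.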
ShardContext.set(null); // Leaves entry, just sets value to null
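The difference between the two calls becomes visible with a ThreadLocal that has an initial value, because get() only runs the initializer when the current thread has no entry at all. In this stdlib snippet the -1 sentinel and the class name are illustrative only; it shows that set(null) leaves an entry behind while remove() deletes it.

```java
// Shows that set(null) keeps the thread's entry while remove() deletes it.
class RemoveVsSetNull {
    // withInitial: get() calls the initializer only when this thread has NO entry
    static final ThreadLocal<Integer> CURRENT_SHARD = ThreadLocal.withInitial(() -> -1);

    static Integer afterSetNull() {
        CURRENT_SHARD.set(3);
        CURRENT_SHARD.set(null);   // entry still present, its value is now null
        return CURRENT_SHARD.get();
    }

    static Integer afterRemove() {
        CURRENT_SHARD.set(3);
        CURRENT_SHARD.remove();    // entry deleted; the next get() re-initializes it
        return CURRENT_SHARD.get();
    }

    public static void main(String[] args) {
        System.out.println(afterSetNull()); // null
        System.out.println(afterRemove());  // -1
    }
}
```

The practical rule is the one ShardingAspect already follows: call remove() in a finally block on every path that set the value.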
// CORRECT - remove() physically removes the entry from ThreadLocalMap.
ShardContext.clear(); // Calls CURRENT_SHARD.remove() internally
Continue to sharding-springboot-part2.md for: Snowflake ID Generation, Outbox Pattern, Scatter-Gather Service, Resilience4j Circuit Breaker, Flyway Multi-Shard Migrations, Health Indicators, Micrometer Metrics, and Integration Testing with Testcontainers.