diff --git a/mm/slab.c b/mm/slab.c
index 95f344d794536fa023ad4b0483bcd685c27519ee..1857a652c92800f9a493c123f79dfcd8523d9ae7 100644
--- a/mm/slab.c
+++ b/mm/slab.c
@@ -3144,14 +3144,10 @@ slab_alloc_node(struct kmem_cache *cachep, gfp_t flags, int nodeid,
 	int slab_node = numa_mem_id();
 
 	flags &= gfp_allowed_mask;
-
-	lockdep_trace_alloc(flags);
-
-	if (should_failslab(cachep, flags))
+	cachep = slab_pre_alloc_hook(cachep, flags);
+	if (unlikely(!cachep))
 		return NULL;
 
-	cachep = memcg_kmem_get_cache(cachep, flags);
-
 	cache_alloc_debugcheck_before(cachep, flags);
 	local_irq_save(save_flags);
 
@@ -3232,14 +3228,10 @@ slab_alloc(struct kmem_cache *cachep, gfp_t flags, unsigned long caller)
 	void *objp;
 
 	flags &= gfp_allowed_mask;
-
-	lockdep_trace_alloc(flags);
-
-	if (should_failslab(cachep, flags))
+	cachep = slab_pre_alloc_hook(cachep, flags);
+	if (unlikely(!cachep))
 		return NULL;
 
-	cachep = memcg_kmem_get_cache(cachep, flags);
-
 	cache_alloc_debugcheck_before(cachep, flags);
 	local_irq_save(save_flags);
 	objp = __do_cache_alloc(cachep, flags);