diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c
index acc9b9da985f81ffd9b485e082cf1781e6731ba2..259d97bc2abd39df8df646c2ebc34ea272e1fd70 100644
--- a/net/sched/sch_cake.c
+++ b/net/sched/sch_cake.c
@@ -1517,16 +1517,27 @@ static unsigned int cake_drop(struct Qdisc *sch, struct sk_buff **to_free)
 
 static u8 cake_handle_diffserv(struct sk_buff *skb, u16 wash)
 {
+	int wlen = skb_network_offset(skb);
 	u8 dscp;
 
-	switch (skb->protocol) {
+	switch (tc_skb_protocol(skb)) {
 	case htons(ETH_P_IP):
+		wlen += sizeof(struct iphdr);
+		if (!pskb_may_pull(skb, wlen) ||
+		    skb_try_make_writable(skb, wlen))
+			return 0;
+
 		dscp = ipv4_get_dsfield(ip_hdr(skb)) >> 2;
 		if (wash && dscp)
 			ipv4_change_dsfield(ip_hdr(skb), INET_ECN_MASK, 0);
 		return dscp;
 
 	case htons(ETH_P_IPV6):
+		wlen += sizeof(struct ipv6hdr);
+		if (!pskb_may_pull(skb, wlen) ||
+		    skb_try_make_writable(skb, wlen))
+			return 0;
+
 		dscp = ipv6_get_dsfield(ipv6_hdr(skb)) >> 2;
 		if (wash && dscp)
 			ipv6_change_dsfield(ipv6_hdr(skb), INET_ECN_MASK, 0);