diff --git a/include/net/sctp/sctp.h b/include/net/sctp/sctp.h
index 06b4f515e15727455c784c54bc572407ebbb5fa4..d7d8cba014697602832fe20e414b632104c9f239 100644
--- a/include/net/sctp/sctp.h
+++ b/include/net/sctp/sctp.h
@@ -127,7 +127,8 @@ int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *),
 				  const union sctp_addr *laddr,
 				  const union sctp_addr *paddr, void *p);
 int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *),
-			    struct net *net, int pos, void *p);
+			    int (*cb_done)(struct sctp_transport *, void *),
+			    struct net *net, int *pos, void *p);
 int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), void *p);
 int sctp_get_sctp_info(struct sock *sk, struct sctp_association *asoc,
 		       struct sctp_info *info);
diff --git a/net/sctp/sctp_diag.c b/net/sctp/sctp_diag.c
index e99518e79b523e973d01e3f6bcfa2c9576c3fdd9..7008a992749bf62aa155c52a45e04609c43c1cac 100644
--- a/net/sctp/sctp_diag.c
+++ b/net/sctp/sctp_diag.c
@@ -279,9 +279,11 @@ static int sctp_tsp_dump_one(struct sctp_transport *tsp, void *p)
 	return err;
 }
 
-static int sctp_sock_dump(struct sock *sk, void *p)
+static int sctp_sock_dump(struct sctp_transport *tsp, void *p)
 {
+	struct sctp_endpoint *ep = tsp->asoc->ep;
 	struct sctp_comm_param *commp = p;
+	struct sock *sk = ep->base.sk;
 	struct sk_buff *skb = commp->skb;
 	struct netlink_callback *cb = commp->cb;
 	const struct inet_diag_req_v2 *r = commp->r;
@@ -289,9 +291,7 @@ static int sctp_sock_dump(struct sock *sk, void *p)
 	int err = 0;
 
 	lock_sock(sk);
-	if (!sctp_sk(sk)->ep)
-		goto release;
-	list_for_each_entry(assoc, &sctp_sk(sk)->ep->asocs, asocs) {
+	list_for_each_entry(assoc, &ep->asocs, asocs) {
 		if (cb->args[4] < cb->args[1])
 			goto next;
 
@@ -327,40 +327,30 @@ static int sctp_sock_dump(struct sock *sk, void *p)
 		cb->args[4]++;
 	}
 	cb->args[1] = 0;
-	cb->args[2]++;
 	cb->args[3] = 0;
 	cb->args[4] = 0;
 release:
 	release_sock(sk);
-	sock_put(sk);
 	return err;
 }
 
-static int sctp_get_sock(struct sctp_transport *tsp, void *p)
+static int sctp_sock_filter(struct sctp_transport *tsp, void *p)
 {
 	struct sctp_endpoint *ep = tsp->asoc->ep;
 	struct sctp_comm_param *commp = p;
 	struct sock *sk = ep->base.sk;
-	struct netlink_callback *cb = commp->cb;
 	const struct inet_diag_req_v2 *r = commp->r;
 	struct sctp_association *assoc =
 		list_entry(ep->asocs.next, struct sctp_association, asocs);
 
 	/* find the ep only once through the transports by this condition */
 	if (tsp->asoc != assoc)
-		goto out;
+		return 0;
 
 	if (r->sdiag_family != AF_UNSPEC && sk->sk_family != r->sdiag_family)
-		goto out;
-
-	sock_hold(sk);
-	cb->args[5] = (long)sk;
+		return 0;
 
 	return 1;
-
-out:
-	cb->args[2]++;
-	return 0;
 }
 
 static int sctp_ep_dump(struct sctp_endpoint *ep, void *p)
@@ -503,12 +493,8 @@ static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
 	if (!(idiag_states & ~(TCPF_LISTEN | TCPF_CLOSE)))
 		goto done;
 
-next:
-	cb->args[5] = 0;
-	sctp_for_each_transport(sctp_get_sock, net, cb->args[2], &commp);
-
-	if (cb->args[5] && !sctp_sock_dump((struct sock *)cb->args[5], &commp))
-		goto next;
+	sctp_for_each_transport(sctp_sock_filter, sctp_sock_dump,
+				net, (int *)&cb->args[2], &commp);
 
 done:
 	cb->args[1] = cb->args[4];
diff --git a/net/sctp/socket.c b/net/sctp/socket.c
index 1b00a1e09b93e4106a38b4f6d45df7175e27598b..d4730ada7f3233367be7a0e3bb10e286a25602c8 100644
--- a/net/sctp/socket.c
+++ b/net/sctp/socket.c
@@ -4658,29 +4658,39 @@ int sctp_transport_lookup_process(int (*cb)(struct sctp_transport *, void *),
 EXPORT_SYMBOL_GPL(sctp_transport_lookup_process);
 
 int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *),
-			    struct net *net, int pos, void *p) {
+			    int (*cb_done)(struct sctp_transport *, void *),
+			    struct net *net, int *pos, void *p) {
 	struct rhashtable_iter hti;
-	void *obj;
-	int err;
-
-	err = sctp_transport_walk_start(&hti);
-	if (err)
-		return err;
+	struct sctp_transport *tsp;
+	int ret;
 
-	obj = sctp_transport_get_idx(net, &hti, pos + 1);
-	for (; !IS_ERR_OR_NULL(obj); obj = sctp_transport_get_next(net, &hti)) {
-		struct sctp_transport *transport = obj;
+again:
+	ret = sctp_transport_walk_start(&hti);
+	if (ret)
+		return ret;
 
-		if (!sctp_transport_hold(transport))
+	tsp = sctp_transport_get_idx(net, &hti, *pos + 1);
+	for (; !IS_ERR_OR_NULL(tsp); tsp = sctp_transport_get_next(net, &hti)) {
+		if (!sctp_transport_hold(tsp))
 			continue;
-		err = cb(transport, p);
-		sctp_transport_put(transport);
-		if (err)
+		ret = cb(tsp, p);
+		if (ret)
 			break;
+		(*pos)++;
+		sctp_transport_put(tsp);
 	}
 	sctp_transport_walk_stop(&hti);
 
-	return err;
+	if (ret) {
+		if (cb_done && !cb_done(tsp, p)) {
+			(*pos)++;
+			sctp_transport_put(tsp);
+			goto again;
+		}
+		sctp_transport_put(tsp);
+	}
+
+	return ret;
 }
 EXPORT_SYMBOL_GPL(sctp_for_each_transport);