The only reason we need separate functions for the IPv4 and IPv6 cases is
to calculate the checksum of the IP pseudo-header, which differs between
the two. However, the caller already knows which path it's on and can
access the values needed for the pseudo-header partial sum more easily
than tcp_update_check_tcp[46]() can.
So, merge these functions into a single tcp_update_csum() function that
just takes the pseudo-header partial sum, calculated in the caller.
Signed-off-by: David Gibson
---
tcp.c | 65 +++++++++++++++++---------------------------------
tcp_internal.h | 9 +++----
tcp_vu.c | 18 +++++++++-----
3 files changed, 37 insertions(+), 55 deletions(-)
diff --git a/tcp.c b/tcp.c
index 5e26243..297eb8c 100644
--- a/tcp.c
+++ b/tcp.c
@@ -753,50 +753,20 @@ static void tcp_sock_set_bufsize(const struct ctx *c, int s)
}
/**
- * tcp_update_check_tcp4() - Calculate TCP checksum for IPv4
- * @iph: IPv4 header
+ * tcp_update_csum() - Calculate TCP checksum and store it in @th->check
+ * @psum: Unfolded partial checksum of the IPv4 or IPv6 pseudo-header
* @th: TCP header (updated)
* @iov: IO vector containing the TCP payload
* @iov_cnt: Length of @iov
* @doffset: TCP payload offset in @iov
*/
-void tcp_update_check_tcp4(const struct iphdr *iph, struct tcphdr *th,
- const struct iovec *iov, int iov_cnt,
- size_t doffset)
+void tcp_update_csum(uint32_t psum, struct tcphdr *th,
+ const struct iovec *iov, int iov_cnt,
+ size_t doffset)
{
- uint16_t l4len = ntohs(iph->tot_len) - sizeof(struct iphdr);
- struct in_addr saddr = { .s_addr = iph->saddr };
- struct in_addr daddr = { .s_addr = iph->daddr };
- uint32_t sum;
-
- sum = proto_ipv4_header_psum(l4len, IPPROTO_TCP, saddr, daddr);
-
- th->check = 0;
- sum = csum_unfolded(th, sizeof(*th), sum);
- th->check = csum_iov(iov, iov_cnt, doffset, sum);
-}
-
-/**
- * tcp_update_check_tcp6() - Calculate TCP checksum for IPv6
- * @ip6h: IPv6 header
- * @th: TCP header (updated)
- * @iov: IO vector containing the TCP payload
- * @iov_cnt: Length of @iov
- * @doffset: TCP payload offset in @iov
- */
-void tcp_update_check_tcp6(const struct ipv6hdr *ip6h, struct tcphdr *th,
- const struct iovec *iov, int iov_cnt,
- size_t doffset)
-{
- uint16_t l4len = ntohs(ip6h->payload_len);
- uint32_t sum;
-
- sum = proto_ipv6_header_psum(l4len, IPPROTO_TCP, &ip6h->saddr,
- &ip6h->daddr);
-
th->check = 0;
- sum = csum_unfolded(th, sizeof(*th), sum);
- th->check = csum_iov(iov, iov_cnt, doffset, sum);
+ psum = csum_unfolded(th, sizeof(*th), psum);
+ th->check = csum_iov(iov, iov_cnt, doffset, psum);
}
/**
@@ -948,10 +918,14 @@ void tcp_fill_headers4(const struct tcp_tap_conn *conn,
tcp_fill_header(th, conn, seq);
- if (no_tcp_csum)
+ if (no_tcp_csum) {
th->check = 0;
- else
- tcp_update_check_tcp4(iph, th, iov, iov_cnt, doffset);
+ } else {
+ uint32_t psum = proto_ipv4_header_psum(l4len, IPPROTO_TCP,
+ *src4, *dst4);
+
+ tcp_update_csum(psum, th, iov, iov_cnt, doffset);
+ }
tap_hdr_update(taph, l3len + sizeof(struct ethhdr));
}
@@ -993,10 +967,15 @@ void tcp_fill_headers6(const struct tcp_tap_conn *conn,
tcp_fill_header(th, conn, seq);
- if (no_tcp_csum)
+ if (no_tcp_csum) {
th->check = 0;
- else
- tcp_update_check_tcp6(ip6h, th, iov, iov_cnt, doffset);
+ } else {
+ uint32_t psum = proto_ipv6_header_psum(l4len, IPPROTO_TCP,
+ &ip6h->saddr,
+ &ip6h->daddr);
+
+ tcp_update_csum(psum, th, iov, iov_cnt, doffset);
+ }
tap_hdr_update(taph, l4len + sizeof(*ip6h) + sizeof(struct ethhdr));
}
diff --git a/tcp_internal.h b/tcp_internal.h
index 8f9267c..a2de15a 100644
--- a/tcp_internal.h
+++ b/tcp_internal.h
@@ -177,12 +177,9 @@ void tcp_rst_do(const struct ctx *c, struct tcp_tap_conn *conn);
struct tcp_info_linux;
-void tcp_update_check_tcp4(const struct iphdr *iph, struct tcphdr *th,
- const struct iovec *iov, int iov_cnt,
- size_t doffset);
-void tcp_update_check_tcp6(const struct ipv6hdr *ip6h, struct tcphdr *th,
- const struct iovec *iov, int iov_cnt,
- size_t doffset);
+void tcp_update_csum(uint32_t psum, struct tcphdr *th,
+ const struct iovec *iov, int iov_cnt,
+ size_t doffset);
void tcp_fill_headers4(const struct tcp_tap_conn *conn,
struct tap_hdr *taph, struct iphdr *iph,
struct tcphdr *th,
diff --git a/tcp_vu.c b/tcp_vu.c
index bf45c74..916e35d 100644
--- a/tcp_vu.c
+++ b/tcp_vu.c
@@ -67,20 +67,26 @@ static void tcp_vu_update_check(const struct flowside *tapside,
struct iovec *iov, int iov_used)
{
char *base = iov[0].iov_base;
+ struct tcphdr *th;
+ uint32_t psum;
if (inany_v4(&tapside->oaddr)) {
- struct tcphdr *th = vu_payloadv4(base);
+ const struct in_addr *src4 = inany_v4(&tapside->oaddr);
+ const struct in_addr *dst4 = inany_v4(&tapside->eaddr);
const struct iphdr *iph = vu_ip(base);
+ size_t l4len = ntohs(iph->tot_len) - sizeof(*iph);
- tcp_update_check_tcp4(iph, th, iov, iov_used,
- (char *)(th + 1) - base);
+ th = vu_payloadv4(base);
+ psum = proto_ipv4_header_psum(l4len, IPPROTO_TCP, *src4, *dst4);
} else {
- struct tcphdr *th = vu_payloadv6(base);
const struct ipv6hdr *ip6h = vu_ip(base);
+ size_t l4len = ntohs(ip6h->payload_len);
- tcp_update_check_tcp6(ip6h, th, iov, iov_used,
- (char *)(th + 1) - base);
+ th = vu_payloadv6(base);
+ psum = proto_ipv6_header_psum(l4len, IPPROTO_TCP,
+ &ip6h->saddr, &ip6h->daddr);
}
+ tcp_update_csum(psum, th, iov, iov_used, (char *)(th + 1) - base);
}
/**
--
2.47.0