The macro CONN_OR_NULL() is used to look up connections by index with bounds checking. Replace it with an inline function, conn_at_idx(), which gives: - Better type checking - No danger of multiple evaluation of an @index with side effects Also add a helper for the reverse translation, from connection pointer to index: a macro, CONN_IDX(), which will make later cleanups easier and safer. Signed-off-by: David Gibson <david(a)gibson.dropbear.id.au> --- tcp.c | 83 ++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/tcp.c b/tcp.c index d043123..34d7d45 100644 --- a/tcp.c +++ b/tcp.c @@ -518,14 +518,6 @@ struct tcp_conn { (conn->events & (SOCK_FIN_RCVD | TAP_FIN_RCVD))) #define CONN_HAS(conn, set) ((conn->events & (set)) == (set)) -#define CONN(index) (tc + (index)) - -/* We probably don't want to use gcc statement expressions (for portability), so - * use this only after well-defined sequence points (no pre-/post-increments). - */ -#define CONN_OR_NULL(index) \ - (((int)(index) >= 0 && (index) < TCP_MAX_CONNS) ?
(tc + (index)) : NULL) - static const char *tcp_event_str[] __attribute((__unused__)) = { "SOCK_ACCEPTED", "TAP_SYN_RCVD", "ESTABLISHED", "TAP_SYN_ACK_SENT", @@ -705,6 +697,21 @@ static size_t tcp6_l2_flags_buf_bytes; /* TCP connections */ static struct tcp_conn tc[TCP_MAX_CONNS]; +#define CONN(index) (tc + (index)) +#define CONN_IDX(conn) ((conn) - tc) + +/** conn_at_idx() - Find a connection by index, if present + * @index: Index of connection to lookup + * + * Return: pointer to connection, or NULL if @index is out of bounds + */ +static inline struct tcp_conn *conn_at_idx(int index) +{ + if ((index < 0) || (index >= TCP_MAX_CONNS)) + return NULL; + return CONN(index); +} + /* Table for lookup from remote address, local port, remote port */ static struct tcp_conn *tc_hash[TCP_HASH_TABLE_SIZE]; @@ -761,7 +768,7 @@ static int tcp_epoll_ctl(const struct ctx *c, struct tcp_conn *conn) { int m = (conn->flags & IN_EPOLL) ? EPOLL_CTL_MOD : EPOLL_CTL_ADD; union epoll_ref ref = { .r.proto = IPPROTO_TCP, .r.s = conn->sock, - .r.p.tcp.tcp.index = conn - tc, + .r.p.tcp.tcp.index = CONN_IDX(conn), .r.p.tcp.tcp.v6 = CONN_V6(conn) }; struct epoll_event ev = { .data.u64 = ref.u64 }; @@ -784,7 +791,7 @@ static int tcp_epoll_ctl(const struct ctx *c, struct tcp_conn *conn) union epoll_ref ref_t = { .r.proto = IPPROTO_TCP, .r.s = conn->sock, .r.p.tcp.tcp.timer = 1, - .r.p.tcp.tcp.index = conn - tc }; + .r.p.tcp.tcp.index = CONN_IDX(conn) }; struct epoll_event ev_t = { .data.u64 = ref_t.u64, .events = EPOLLIN | EPOLLET }; @@ -813,7 +820,7 @@ static void tcp_timer_ctl(const struct ctx *c, struct tcp_conn *conn) union epoll_ref ref = { .r.proto = IPPROTO_TCP, .r.s = conn->sock, .r.p.tcp.tcp.timer = 1, - .r.p.tcp.tcp.index = conn - tc }; + .r.p.tcp.tcp.index = CONN_IDX(conn) }; struct epoll_event ev = { .data.u64 = ref.u64, .events = EPOLLIN | EPOLLET }; int fd; @@ -846,7 +853,7 @@ static void tcp_timer_ctl(const struct ctx *c, struct tcp_conn *conn) it.it_value.tv_sec = ACT_TIMEOUT; } 
- debug("TCP: index %li, timer expires in %lu.%03lus", conn - tc, + debug("TCP: index %li, timer expires in %lu.%03lus", CONN_IDX(conn), it.it_value.tv_sec, it.it_value.tv_nsec / 1000 / 1000); timerfd_settime(conn->timer, 0, &it, NULL); @@ -867,7 +874,7 @@ static void conn_flag_do(const struct ctx *c, struct tcp_conn *conn, conn->flags &= flag; if (fls(~flag) >= 0) { - debug("TCP: index %li: %s dropped", conn - tc, + debug("TCP: index %li: %s dropped", CONN_IDX(conn), tcp_flag_str[fls(~flag)]); } } else { @@ -876,7 +883,7 @@ static void conn_flag_do(const struct ctx *c, struct tcp_conn *conn, conn->flags |= flag; if (fls(flag) >= 0) { - debug("TCP: index %li: %s", conn - tc, + debug("TCP: index %li: %s", CONN_IDX(conn), tcp_flag_str[fls(flag)]); } } @@ -926,12 +933,12 @@ static void conn_event_do(const struct ctx *c, struct tcp_conn *conn, new += 5; if (prev != new) { - debug("TCP: index %li, %s: %s -> %s", conn - tc, + debug("TCP: index %li, %s: %s -> %s", CONN_IDX(conn), num == -1 ? "CLOSED" : tcp_event_str[num], prev == -1 ? "CLOSED" : tcp_state_str[prev], (new == -1 || num == -1) ? "CLOSED" : tcp_state_str[new]); } else { - debug("TCP: index %li, %s", conn - tc, + debug("TCP: index %li, %s", CONN_IDX(conn), num == -1 ? "CLOSED" : tcp_event_str[num]); } @@ -1355,12 +1362,12 @@ static void tcp_hash_insert(const struct ctx *c, struct tcp_conn *conn, int b; b = tcp_hash(c, af, addr, conn->tap_port, conn->sock_port); - conn->next_index = tc_hash[b] ? tc_hash[b] - tc : -1; + conn->next_index = tc_hash[b] ? 
CONN_IDX(tc_hash[b]) : -1; tc_hash[b] = conn; conn->hash_bucket = b; debug("TCP: hash table insert: index %li, sock %i, bucket: %i, next: " - "%p", conn - tc, conn->sock, b, CONN_OR_NULL(conn->next_index)); + "%p", CONN_IDX(conn), conn->sock, b, conn_at_idx(conn->next_index)); } /** @@ -1373,19 +1380,19 @@ static void tcp_hash_remove(const struct tcp_conn *conn) int b = conn->hash_bucket; for (entry = tc_hash[b]; entry; - prev = entry, entry = CONN_OR_NULL(entry->next_index)) { + prev = entry, entry = conn_at_idx(entry->next_index)) { if (entry == conn) { if (prev) prev->next_index = conn->next_index; else - tc_hash[b] = CONN_OR_NULL(conn->next_index); + tc_hash[b] = conn_at_idx(conn->next_index); break; } } debug("TCP: hash table remove: index %li, sock %i, bucket: %i, new: %p", - conn - tc, conn->sock, b, - prev ? CONN_OR_NULL(prev->next_index) : tc_hash[b]); + CONN_IDX(conn), conn->sock, b, + prev ? conn_at_idx(prev->next_index) : tc_hash[b]); } /** @@ -1399,10 +1406,10 @@ static void tcp_hash_update(struct tcp_conn *old, struct tcp_conn *new) int b = old->hash_bucket; for (entry = tc_hash[b]; entry; - prev = entry, entry = CONN_OR_NULL(entry->next_index)) { + prev = entry, entry = conn_at_idx(entry->next_index)) { if (entry == old) { if (prev) - prev->next_index = new - tc; + prev->next_index = CONN_IDX(new); else tc_hash[b] = new; break; @@ -1411,7 +1418,7 @@ static void tcp_hash_update(struct tcp_conn *old, struct tcp_conn *new) debug("TCP: hash table update: old index %li, new index %li, sock %i, " "bucket: %i, old: %p, new: %p", - old - tc, new - tc, new->sock, b, old, new); + CONN_IDX(old), CONN_IDX(new), new->sock, b, old, new); } /** @@ -1431,7 +1438,7 @@ static struct tcp_conn *tcp_hash_lookup(const struct ctx *c, int af, int b = tcp_hash(c, af, addr, tap_port, sock_port); struct tcp_conn *conn; - for (conn = tc_hash[b]; conn; conn = CONN_OR_NULL(conn->next_index)) { + for (conn = tc_hash[b]; conn; conn = conn_at_idx(conn->next_index)) { if 
(tcp_hash_match(conn, af, addr, tap_port, sock_port)) return conn; } @@ -1448,9 +1455,9 @@ static void tcp_table_compact(struct ctx *c, struct tcp_conn *hole) { struct tcp_conn *from, *to; - if ((hole - tc) == --c->tcp.conn_count) { + if (CONN_IDX(hole) == --c->tcp.conn_count) { debug("TCP: hash table compaction: maximum index was %li (%p)", - hole - tc, hole); + CONN_IDX(hole), hole); memset(hole, 0, sizeof(*hole)); return; } @@ -1465,7 +1472,7 @@ static void tcp_table_compact(struct ctx *c, struct tcp_conn *hole) debug("TCP: hash table compaction: old index %li, new index %li, " "sock %i, from: %p, to: %p", - from - tc, to - tc, from->sock, from, to); + CONN_IDX(from), CONN_IDX(to), from->sock, from, to); memset(from, 0, sizeof(*from)); } @@ -1488,7 +1495,7 @@ static void tcp_conn_destroy(struct ctx *c, struct tcp_conn *conn) static void tcp_rst_do(struct ctx *c, struct tcp_conn *conn); #define tcp_rst(c, conn) \ do { \ - debug("TCP: index %li, reset at %s:%i", conn - tc, \ + debug("TCP: index %li, reset at %s:%i", CONN_IDX(conn), \ __func__, __LINE__); \ tcp_rst_do(c, conn); \ } while (0) @@ -2734,7 +2741,7 @@ int tcp_tap_handler(struct ctx *c, int af, const void *addr, return 1; } - trace("TCP: packet length %lu from tap for index %lu", len, conn - tc); + trace("TCP: packet length %lu from tap for index %lu", len, CONN_IDX(conn)); if (th->rst) { conn_event(c, conn, CLOSED); @@ -2942,7 +2949,7 @@ static void tcp_conn_from_sock(struct ctx *c, union epoll_ref ref, */ static void tcp_timer_handler(struct ctx *c, union epoll_ref ref) { - struct tcp_conn *conn = CONN_OR_NULL(ref.r.p.tcp.tcp.index); + struct tcp_conn *conn = conn_at_idx(ref.r.p.tcp.tcp.index); struct itimerspec check_armed = { { 0 }, { 0 } }; if (!conn) @@ -2961,17 +2968,17 @@ static void tcp_timer_handler(struct ctx *c, union epoll_ref ref) conn_flag(c, conn, ~ACK_TO_TAP_DUE); } else if (conn->flags & ACK_FROM_TAP_DUE) { if (!(conn->events & ESTABLISHED)) { - debug("TCP: index %li, handshake 
timeout", conn - tc); + debug("TCP: index %li, handshake timeout", CONN_IDX(conn)); tcp_rst(c, conn); } else if (CONN_HAS(conn, SOCK_FIN_SENT | TAP_FIN_ACKED)) { - debug("TCP: index %li, FIN timeout", conn - tc); + debug("TCP: index %li, FIN timeout", CONN_IDX(conn)); tcp_rst(c, conn); } else if (conn->retrans == TCP_MAX_RETRANS) { debug("TCP: index %li, retransmissions count exceeded", - conn - tc); + CONN_IDX(conn)); tcp_rst(c, conn); } else { - debug("TCP: index %li, ACK timeout, retry", conn - tc); + debug("TCP: index %li, ACK timeout, retry", CONN_IDX(conn)); conn->retrans++; conn->seq_to_tap = conn->seq_ack_from_tap; tcp_data_from_sock(c, conn); @@ -2989,7 +2996,7 @@ static void tcp_timer_handler(struct ctx *c, union epoll_ref ref) */ timerfd_settime(conn->timer, 0, &new, &old); if (old.it_value.tv_sec == ACT_TIMEOUT) { - debug("TCP: index %li, activity timeout", conn - tc); + debug("TCP: index %li, activity timeout", CONN_IDX(conn)); tcp_rst(c, conn); } } @@ -3022,7 +3029,7 @@ void tcp_sock_handler(struct ctx *c, union epoll_ref ref, uint32_t events, return; } - if (!(conn = CONN_OR_NULL(ref.r.p.tcp.tcp.index))) + if (!(conn = conn_at_idx(ref.r.p.tcp.tcp.index))) return; if (conn->events == CLOSED) -- 2.38.1