diff --git a/webrtc/p2p/base/turnport.cc b/webrtc/p2p/base/turnport.cc index f1f13d0cff..0255395012 100644 --- a/webrtc/p2p/base/turnport.cc +++ b/webrtc/p2p/base/turnport.cc @@ -533,11 +533,15 @@ int TurnPort::SendTo(const void* data, size_t size, return static_cast(size); } -void TurnPort::OnReadPacket( - rtc::AsyncPacketSocket* socket, const char* data, size_t size, - const rtc::SocketAddress& remote_addr, - const rtc::PacketTime& packet_time) { - ASSERT(socket == socket_); +bool TurnPort::HandleIncomingPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + if (socket != socket_) { + // The packet was received on a shared socket after we've allocated a new + // socket for this TURN port. + return false; + } // This is to guard against a STUN response from previous server after // alternative server redirection. TODO(guoweis): add a unit test for this @@ -547,19 +551,19 @@ void TurnPort::OnReadPacket( << remote_addr.ToString() << ", server_address_:" << server_address_.address.ToString(); - return; + return false; } // The message must be at least the size of a channel header. if (size < TURN_CHANNEL_HEADER_SIZE) { LOG_J(LS_WARNING, this) << "Received TURN message that was too short"; - return; + return false; } if (state_ == STATE_DISCONNECTED) { LOG_J(LS_WARNING, this) << "Received TURN message while the Turn port is disconnected"; - return; + return false; } // Check the message type, to see if is a Channel Data message. @@ -568,27 +572,41 @@ void TurnPort::OnReadPacket( uint16_t msg_type = rtc::GetBE16(data); if (IsTurnChannelData(msg_type)) { HandleChannelData(msg_type, data, size, packet_time); - } else if (msg_type == TURN_DATA_INDICATION) { - HandleDataIndication(data, size, packet_time); - } else { - if (SharedSocket() && - (msg_type == STUN_BINDING_RESPONSE || - msg_type == STUN_BINDING_ERROR_RESPONSE)) { - LOG_J(LS_VERBOSE, this) << - "Ignoring STUN binding response message on shared socket."; - return; - } + return true; - // This must be a response for one of our requests. - // Check success responses, but not errors, for MESSAGE-INTEGRITY. - if (IsStunSuccessResponseType(msg_type) && - !StunMessage::ValidateMessageIntegrity(data, size, hash())) { - LOG_J(LS_WARNING, this) << "Received TURN message with invalid " - << "message integrity, msg_type=" << msg_type; - return; - } - request_manager_.CheckResponse(data, size); } + + if (msg_type == TURN_DATA_INDICATION) { + HandleDataIndication(data, size, packet_time); + return true; + } + + if (SharedSocket() && (msg_type == STUN_BINDING_RESPONSE || + msg_type == STUN_BINDING_ERROR_RESPONSE)) { + LOG_J(LS_VERBOSE, this) << + "Ignoring STUN binding response message on shared socket."; + return false; + } + + // This must be a response for one of our requests. + // Check success responses, but not errors, for MESSAGE-INTEGRITY. + if (IsStunSuccessResponseType(msg_type) && + !StunMessage::ValidateMessageIntegrity(data, size, hash())) { + LOG_J(LS_WARNING, this) << "Received TURN message with invalid " + << "message integrity, msg_type=" << msg_type; + return true; + } + request_manager_.CheckResponse(data, size); + + return true; +} + +void TurnPort::OnReadPacket(rtc::AsyncPacketSocket* socket, + const char* data, + size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time) { + HandleIncomingPacket(socket, data, size, remote_addr, packet_time); } void TurnPort::OnSentPacket(rtc::AsyncPacketSocket* socket, diff --git a/webrtc/p2p/base/turnport.h b/webrtc/p2p/base/turnport.h index 797fa3f94f..461fc1304d 100644 --- a/webrtc/p2p/base/turnport.h +++ b/webrtc/p2p/base/turnport.h @@ -94,13 +94,10 @@ class TurnPort : public Port { virtual int GetOption(rtc::Socket::Option opt, int* value); virtual int GetError(); - virtual bool HandleIncomingPacket( - rtc::AsyncPacketSocket* socket, const char* data, size_t size, - const rtc::SocketAddress& remote_addr, - const rtc::PacketTime& packet_time) { - OnReadPacket(socket, data, size, remote_addr, packet_time); - return true; - } + virtual bool HandleIncomingPacket(rtc::AsyncPacketSocket* socket, + const char* data, size_t size, + const rtc::SocketAddress& remote_addr, + const rtc::PacketTime& packet_time); virtual void OnReadPacket(rtc::AsyncPacketSocket* socket, const char* data, size_t size, const rtc::SocketAddress& remote_addr, diff --git a/webrtc/p2p/base/turnport_unittest.cc b/webrtc/p2p/base/turnport_unittest.cc index a8f2d98188..15a7954d6d 100644 --- a/webrtc/p2p/base/turnport_unittest.cc +++ b/webrtc/p2p/base/turnport_unittest.cc @@ -666,6 +666,13 @@ TEST_F(TurnPortTest, TestTurnAllocateMismatch) { // Verifies that the new port has a different address now. EXPECT_NE(first_addr, turn_port_->socket()->GetLocalAddress()); + + // Verify that all packets received from the shared socket are ignored. + std::string test_packet = "Test packet"; + EXPECT_FALSE(turn_port_->HandleIncomingPacket( + socket_.get(), test_packet.data(), test_packet.size(), + rtc::SocketAddress(kTurnUdpExtAddr.ipaddr(), 0), + rtc::CreatePacketTime(0))); } // Tests that a shared-socket-TurnPort creates its own socket after diff --git a/webrtc/p2p/client/basicportallocator.cc b/webrtc/p2p/client/basicportallocator.cc index 22abf33e93..edfd0219d0 100644 --- a/webrtc/p2p/client/basicportallocator.cc +++ b/webrtc/p2p/client/basicportallocator.cc @@ -1080,13 +1080,13 @@ void AllocationSequence::OnReadPacket( // a STUN binding response, so we pass the message to TurnPort regardless of // the message type. The TurnPort will just ignore the message since it will // not find any request by transaction ID. - for (std::vector::const_iterator it = turn_ports_.begin(); - it != turn_ports_.end(); ++it) { - TurnPort* port = *it; + for (TurnPort* port : turn_ports_) { if (port->server_address().address == remote_addr) { - port->HandleIncomingPacket(socket, data, size, remote_addr, packet_time); + if (port->HandleIncomingPacket(socket, data, size, remote_addr, + packet_time)) { + return; + } turn_port_found = true; - break; } } @@ -1097,8 +1097,9 @@ void AllocationSequence::OnReadPacket( // the TURN server is also a STUN server. if (!turn_port_found || stun_servers.find(remote_addr) != stun_servers.end()) { - udp_port_->HandleIncomingPacket( - socket, data, size, remote_addr, packet_time); + RTC_DCHECK(udp_port_->SharedSocket()); + udp_port_->HandleIncomingPacket(socket, data, size, remote_addr, + packet_time); } } }