/* * Copyright 2018 The WebRTC Project Authors. All rights reserved. * * Use of this source code is governed by a BSD-style license * that can be found in the LICENSE file in the root of the source * tree. An additional intellectual property rights grant can be found * in the file PATENTS. All contributing project authors may * be found in the AUTHORS file in the root of the source tree. */ #include "p2p/base/mdns_message.h" #include "rtc_base/logging.h" #include "rtc_base/nethelpers.h" #include "rtc_base/stringencode.h" namespace webrtc { namespace { // RFC 1035, Section 4.1.1. // // QR bit. constexpr uint16_t kMDnsFlagMaskQueryOrResponse = 0x8000; // AA bit. constexpr uint16_t kMDnsFlagMaskAuthoritative = 0x0400; // RFC 1035, Section 4.1.2, QCLASS and RFC 6762, Section 18.12, repurposing of // top bit of QCLASS as the unicast response bit. constexpr uint16_t kMDnsQClassMaskUnicastResponse = 0x8000; constexpr size_t kMDnsHeaderSizeBytes = 12; bool ReadDomainName(MessageBufferReader* buf, std::string* name) { size_t name_start_pos = buf->CurrentOffset(); uint8_t label_length; if (!buf->ReadUInt8(&label_length)) { return false; } // RFC 1035, Section 4.1.4. // // If the first two bits of the length octet are ones, the name is compressed // and the rest six bits with the next octet denotes its position in the // message by the offset from the start of the message. auto is_pointer = [](uint8_t octet) { return (octet & 0x80) && (octet & 0x40); }; while (label_length && !is_pointer(label_length)) { // RFC 1035, Section 2.3.1, labels are restricted to 63 octets or less. if (label_length > 63) { return false; } std::string label; if (!buf->ReadString(&label, label_length)) { return false; } (*name) += label + "."; if (!buf->ReadUInt8(&label_length)) { return false; } } if (is_pointer(label_length)) { uint8_t next_octet; if (!buf->ReadUInt8(&next_octet)) { return false; } size_t pos_jump_to = ((label_length & 0x3f) << 8) | next_octet; // A legitimate pointer only refers to a prior occurrence of the same name, // and we should only move strictly backward to a prior name field after the // header. if (pos_jump_to >= name_start_pos || pos_jump_to < kMDnsHeaderSizeBytes) { return false; } MessageBufferReader new_buf(buf->MessageData(), buf->MessageLength()); if (!new_buf.Consume(pos_jump_to)) { return false; } return ReadDomainName(&new_buf, name); } return true; } void WriteDomainName(rtc::ByteBufferWriter* buf, const std::string& name) { std::vector labels; rtc::tokenize(name, '.', &labels); for (const auto& label : labels) { buf->WriteUInt8(label.length()); buf->WriteString(label); } buf->WriteUInt8(0); } } // namespace void MDnsHeader::SetQueryOrResponse(bool is_query) { if (is_query) { flags &= ~kMDnsFlagMaskQueryOrResponse; } else { flags |= kMDnsFlagMaskQueryOrResponse; } } void MDnsHeader::SetAuthoritative(bool is_authoritative) { if (is_authoritative) { flags |= kMDnsFlagMaskAuthoritative; } else { flags &= ~kMDnsFlagMaskAuthoritative; } } bool MDnsHeader::IsAuthoritative() const { return flags & kMDnsFlagMaskAuthoritative; } bool MDnsHeader::Read(MessageBufferReader* buf) { if (!buf->ReadUInt16(&id) || !buf->ReadUInt16(&flags) || !buf->ReadUInt16(&qdcount) || !buf->ReadUInt16(&ancount) || !buf->ReadUInt16(&nscount) || !buf->ReadUInt16(&arcount)) { RTC_LOG(LS_ERROR) << "Invalid mDNS header."; return false; } return true; } void MDnsHeader::Write(rtc::ByteBufferWriter* buf) const { buf->WriteUInt16(id); buf->WriteUInt16(flags); buf->WriteUInt16(qdcount); buf->WriteUInt16(ancount); buf->WriteUInt16(nscount); buf->WriteUInt16(arcount); } bool MDnsHeader::IsQuery() const { return !(flags & kMDnsFlagMaskQueryOrResponse); } MDnsSectionEntry::MDnsSectionEntry() = default; MDnsSectionEntry::~MDnsSectionEntry() = default; MDnsSectionEntry::MDnsSectionEntry(const MDnsSectionEntry& other) = default; void MDnsSectionEntry::SetType(SectionEntryType type) { switch (type) { case SectionEntryType::kA: type_ = 1; return; case SectionEntryType::kAAAA: type_ = 28; return; default: RTC_NOTREACHED(); } } SectionEntryType MDnsSectionEntry::GetType() const { switch (type_) { case 1: return SectionEntryType::kA; case 28: return SectionEntryType::kAAAA; default: return SectionEntryType::kUnsupported; } } void MDnsSectionEntry::SetClass(SectionEntryClass cls) { switch (cls) { case SectionEntryClass::kIN: class_ = 1; return; default: RTC_NOTREACHED(); } } SectionEntryClass MDnsSectionEntry::GetClass() const { switch (class_) { case 1: return SectionEntryClass::kIN; default: return SectionEntryClass::kUnsupported; } } MDnsQuestion::MDnsQuestion() = default; MDnsQuestion::MDnsQuestion(const MDnsQuestion& other) = default; MDnsQuestion::~MDnsQuestion() = default; bool MDnsQuestion::Read(MessageBufferReader* buf) { if (!ReadDomainName(buf, &name_)) { RTC_LOG(LS_ERROR) << "Invalid name."; return false; } if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_)) { RTC_LOG(LS_ERROR) << "Invalid type and class."; return false; } return true; } bool MDnsQuestion::Write(rtc::ByteBufferWriter* buf) const { WriteDomainName(buf, name_); buf->WriteUInt16(type_); buf->WriteUInt16(class_); return true; } void MDnsQuestion::SetUnicastResponse(bool should_unicast) { if (should_unicast) { class_ |= kMDnsQClassMaskUnicastResponse; } else { class_ &= ~kMDnsQClassMaskUnicastResponse; } } bool MDnsQuestion::ShouldUnicastResponse() const { return class_ & kMDnsQClassMaskUnicastResponse; } MDnsResourceRecord::MDnsResourceRecord() = default; MDnsResourceRecord::MDnsResourceRecord(const MDnsResourceRecord& other) = default; MDnsResourceRecord::~MDnsResourceRecord() = default; bool MDnsResourceRecord::Read(MessageBufferReader* buf) { if (!ReadDomainName(buf, &name_)) { return false; } if (!buf->ReadUInt16(&type_) || !buf->ReadUInt16(&class_) || !buf->ReadUInt32(&ttl_seconds_) || !buf->ReadUInt16(&rdlength_)) { return false; } switch (GetType()) { case SectionEntryType::kA: return ReadARData(buf); case SectionEntryType::kAAAA: return ReadQuadARData(buf); case SectionEntryType::kUnsupported: return false; default: RTC_NOTREACHED(); } return false; } bool MDnsResourceRecord::ReadARData(MessageBufferReader* buf) { // A RDATA contains a 32-bit IPv4 address. return buf->ReadString(&rdata_, 4); } bool MDnsResourceRecord::ReadQuadARData(MessageBufferReader* buf) { // AAAA RDATA contains a 128-bit IPv6 address. return buf->ReadString(&rdata_, 16); } bool MDnsResourceRecord::Write(rtc::ByteBufferWriter* buf) const { WriteDomainName(buf, name_); buf->WriteUInt16(type_); buf->WriteUInt16(class_); buf->WriteUInt32(ttl_seconds_); buf->WriteUInt16(rdlength_); switch (GetType()) { case SectionEntryType::kA: WriteARData(buf); return true; case SectionEntryType::kAAAA: WriteQuadARData(buf); return true; case SectionEntryType::kUnsupported: return false; default: RTC_NOTREACHED(); } return true; } void MDnsResourceRecord::WriteARData(rtc::ByteBufferWriter* buf) const { buf->WriteString(rdata_); } void MDnsResourceRecord::WriteQuadARData(rtc::ByteBufferWriter* buf) const { buf->WriteString(rdata_); } bool MDnsResourceRecord::SetIPAddressInRecordData( const rtc::IPAddress& address) { int af = address.family(); if (af != AF_INET && af != AF_INET6) { return false; } char out[16] = {0}; if (!rtc::inet_pton(af, address.ToString().c_str(), out)) { return false; } rdlength_ = (af == AF_INET) ? 4 : 16; rdata_ = std::string(out, rdlength_); return true; } bool MDnsResourceRecord::GetIPAddressFromRecordData( rtc::IPAddress* address) const { if (GetType() != SectionEntryType::kA && GetType() != SectionEntryType::kAAAA) { return false; } if (rdata_.size() != 4 && rdata_.size() != 16) { return false; } char out[INET6_ADDRSTRLEN] = {0}; int af = (GetType() == SectionEntryType::kA) ? AF_INET : AF_INET6; if (!rtc::inet_ntop(af, rdata_.data(), out, sizeof(out))) { return false; } return rtc::IPFromString(std::string(out), address); } MDnsMessage::MDnsMessage() = default; MDnsMessage::~MDnsMessage() = default; bool MDnsMessage::Read(MessageBufferReader* buf) { RTC_DCHECK_EQ(0u, buf->CurrentOffset()); if (!header_.Read(buf)) { return false; } auto read_question = [&buf](std::vector* section, uint16_t count) { section->resize(count); for (auto& question : (*section)) { if (!question.Read(buf)) { return false; } } return true; }; auto read_rr = [&buf](std::vector* section, uint16_t count) { section->resize(count); for (auto& rr : (*section)) { if (!rr.Read(buf)) { return false; } } return true; }; if (!read_question(&question_section_, header_.qdcount) || !read_rr(&answer_section_, header_.ancount) || !read_rr(&authority_section_, header_.nscount) || !read_rr(&additional_section_, header_.arcount)) { return false; } return true; } bool MDnsMessage::Write(rtc::ByteBufferWriter* buf) const { header_.Write(buf); auto write_rr = [&buf](const std::vector& section) { for (auto rr : section) { if (!rr.Write(buf)) { return false; } } return true; }; for (auto question : question_section_) { if (!question.Write(buf)) { return false; } } if (!write_rr(answer_section_) || !write_rr(authority_section_) || !write_rr(additional_section_)) { return false; } return true; } bool MDnsMessage::ShouldUnicastResponse() const { bool should_unicast = false; for (const auto& question : question_section_) { should_unicast |= question.ShouldUnicastResponse(); } return should_unicast; } void MDnsMessage::AddQuestion(const MDnsQuestion& question) { question_section_.push_back(question); header_.qdcount = question_section_.size(); } void MDnsMessage::AddAnswerRecord(const MDnsResourceRecord& answer) { answer_section_.push_back(answer); header_.ancount = answer_section_.size(); } } // namespace webrtc