fix off by one error when removing ricochet instances #11

Closed
dan wants to merge 1 commits from dan/libricochet-go:master into master
3 changed files with 300 additions and 381 deletions

View File

@ -98,147 +98,175 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
return return
} }
packets, err := r.rni.RecvRicochetPackets(oc.conn) packet, err := r.rni.RecvRicochetPacket(oc.conn)
if err != nil { if err != nil {
oc.Close()
return return
} }
for _, packet := range packets { if len(packet.Data) == 0 {
service.OnChannelClosed(oc, packet.Channel)
continue
}
if len(packet.Data) == 0 { if packet.Channel == 0 {
service.OnChannelClosed(oc, packet.Channel)
res := new(Protocol_Data_Control.Packet)
err := proto.Unmarshal(packet.Data[:], res)
if err != nil {
service.OnGenericError(oc, packet.Channel)
continue continue
} }
if packet.Channel == 0 { if res.GetOpenChannel() != nil {
opm := res.GetOpenChannel()
res := new(Protocol_Data_Control.Packet) if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" {
err := proto.Unmarshal(packet.Data[:], res) // Channel is already in use.
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
if err != nil {
service.OnGenericError(oc, packet.Channel)
continue continue
} }
if res.GetOpenChannel() != nil { // If I am a Client, the server can only open even numbered channels
opm := res.GetOpenChannel() if oc.Client && opm.GetChannelIdentifier()%2 != 0 {
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue
}
if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" { // If I am a Server, the client can only open odd numbered channels
// Channel is already in use. if !oc.Client && opm.GetChannelIdentifier()%2 != 1 {
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue
}
switch opm.GetChannelType() {
case "im.ricochet.auth.hidden-service":
if oc.Client {
// Servers are authed by default and can't auth with hidden-service
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue } else if oc.IsAuthed {
} // Can't auth if already authed
// If I am a Client, the server can only open even numbered channels
if oc.Client && opm.GetChannelIdentifier()%2 != 0 {
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue } else if oc.HasChannel("im.ricochet.auth.hidden-service") {
} // Can't open more than 1 auth channel
// If I am a Server, the client can only open odd numbered channels
if !oc.Client && opm.GetChannelIdentifier()%2 != 1 {
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) service.OnBadUsageError(oc, opm.GetChannelIdentifier())
continue } else {
} clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie)
if err == nil {
switch opm.GetChannelType() { clientCookieB := [16]byte{}
case "im.ricochet.auth.hidden-service": copy(clientCookieB[:], clientCookie.([]byte)[:])
if oc.Client { service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB)
// Servers are authed by default and can't auth with hidden-service
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else if oc.IsAuthed {
// Can't auth if already authed
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else if oc.HasChannel("im.ricochet.auth.hidden-service") {
// Can't open more than 1 auth channel
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else { } else {
clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie) // Must include Client Cookie
if err == nil { service.OnBadUsageError(oc, opm.GetChannelIdentifier())
clientCookieB := [16]byte{} }
copy(clientCookieB[:], clientCookie.([]byte)[:]) }
service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB) case "im.ricochet.chat":
} else { if !oc.IsAuthed {
// Must include Client Cookie // Can't open chat channel if not authorized
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else if !service.IsKnownContact(oc.OtherHostname) {
// Can't open chat channel if not a known contact
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else {
service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat")
}
case "im.ricochet.contact.request":
if oc.Client {
// Servers are not allowed to send contact requests
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else if !oc.IsAuthed {
// Can't open a contact channel if not authed
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else if oc.HasChannel("im.ricochet.contact.request") {
// Only 1 contact channel is allowed to be open at a time
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else {
contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest)
if err == nil {
contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest)
if check {
service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText())
break
} }
} }
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
}
default:
service.OnUnknownTypeError(oc, opm.GetChannelIdentifier())
}
} else if res.GetChannelResult() != nil {
crm := res.GetChannelResult()
if crm.GetOpened() {
switch oc.GetChannelType(crm.GetChannelIdentifier()) {
case "im.ricochet.auth.hidden-service":
serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie)
if err == nil {
serverCookieB := [16]byte{}
copy(serverCookieB[:], serverCookie.([]byte)[:])
service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB)
} else {
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
}
case "im.ricochet.chat": case "im.ricochet.chat":
if !oc.IsAuthed { service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier())
// Can't open chat channel if not authorized
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else if !service.IsKnownContact(oc.OtherHostname) {
// Can't open chat channel if not a known contact
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
} else {
service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat")
}
case "im.ricochet.contact.request": case "im.ricochet.contact.request":
if oc.Client { responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response)
// Servers are not allowed to send contact requests if err == nil {
service.OnBadUsageError(oc, opm.GetChannelIdentifier()) response, check := responseI.(*Protocol_Data_ContactRequest.Response)
} else if !oc.IsAuthed { if check {
// Can't open a contact channel if not authed service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String())
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier()) break
} else if oc.HasChannel("im.ricochet.contact.request") {
// Only 1 contact channel is allowed to be open at a time
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} else {
contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest)
if err == nil {
contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest)
if check {
service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText())
break
}
} }
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
} }
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
default: default:
service.OnUnknownTypeError(oc, opm.GetChannelIdentifier()) service.OnBadUsageError(oc, crm.GetChannelIdentifier())
}
} else if res.GetChannelResult() != nil {
crm := res.GetChannelResult()
if crm.GetOpened() {
switch oc.GetChannelType(crm.GetChannelIdentifier()) {
case "im.ricochet.auth.hidden-service":
serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie)
if err == nil {
serverCookieB := [16]byte{}
copy(serverCookieB[:], serverCookie.([]byte)[:])
service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB)
} else {
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
}
case "im.ricochet.chat":
service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier())
case "im.ricochet.contact.request":
responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response)
if err == nil {
response, check := responseI.(*Protocol_Data_ContactRequest.Response)
if check {
service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String())
break
}
}
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
default:
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
}
} else {
if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" {
service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String())
} else {
oc.CloseChannel(crm.GetChannelIdentifier())
}
} }
} else { } else {
// Unknown Message if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" {
oc.CloseChannel(packet.Channel) service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String())
} else {
oc.CloseChannel(crm.GetChannelIdentifier())
}
} }
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" { } else {
res := new(Protocol_Data_AuthHiddenService.Packet) // Unknown Message
oc.CloseChannel(packet.Channel)
}
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" {
res := new(Protocol_Data_AuthHiddenService.Packet)
err := proto.Unmarshal(packet.Data[:], res)
if err != nil {
oc.CloseChannel(packet.Channel)
continue
}
if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs
service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname))
} else if res.GetResult() != nil && oc.Client { // Only Servers Send Results
service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
} else {
// If neither of the above are satisfied we just close the connection
oc.Close()
}
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" {
// NOTE: These auth checks should be redundant, however they
// are included here for defense-in-depth if for some reason
// a previously authed connection becomes untrusted / not known and
// the state is not cleaned up.
if !oc.IsAuthed {
// Can't send chat messages if not authorized
service.OnUnauthorizedError(oc, packet.Channel)
} else if !service.IsKnownContact(oc.OtherHostname) {
// Can't send chat message if not a known contact
service.OnUnauthorizedError(oc, packet.Channel)
} else {
res := new(Protocol_Data_Chat.Packet)
err := proto.Unmarshal(packet.Data[:], res) err := proto.Unmarshal(packet.Data[:], res)
if err != nil { if err != nil {
@ -246,73 +274,42 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
continue continue
} }
if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs if res.GetChatMessage() != nil {
service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname)) service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText())
} else if res.GetResult() != nil && oc.Client { // Only Servers Send Results } else if res.GetChatAcknowledge() != nil {
service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact()) service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()))
} else { } else {
// If neither of the above are satisfied we just close the connection // If neither of the above are satisfied we just close the connection
oc.Close() oc.Close()
} }
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" {
// NOTE: These auth checks should be redundant, however they
// are included here for defense-in-depth if for some reason
// a previously authed connection becomes untrusted / not known and
// the state is not cleaned up.
if !oc.IsAuthed {
// Can't send chat messages if not authorized
service.OnUnauthorizedError(oc, packet.Channel)
} else if !service.IsKnownContact(oc.OtherHostname) {
// Can't send chat message if not a known contact
service.OnUnauthorizedError(oc, packet.Channel)
} else {
res := new(Protocol_Data_Chat.Packet)
err := proto.Unmarshal(packet.Data[:], res)
if err != nil {
oc.CloseChannel(packet.Channel)
continue
}
if res.GetChatMessage() != nil {
service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText())
} else if res.GetChatAcknowledge() != nil {
service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()))
} else {
// If neither of the above are satisfied we just close the connection
oc.Close()
}
}
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" {
// NOTE: These auth checks should be redundant, however they
// are included here for defense-in-depth if for some reason
// a previously authed connection becomes untrusted / not known and
// the state is not cleaned up.
if !oc.Client {
// Clients are not allowed to send contact request responses
service.OnBadUsageError(oc, packet.Channel)
} else if !oc.IsAuthed {
// Can't send a contact request if not authed
service.OnBadUsageError(oc, packet.Channel)
} else {
res := new(Protocol_Data_ContactRequest.Response)
err := proto.Unmarshal(packet.Data[:], res)
log.Printf("%v", res)
if err != nil {
oc.CloseChannel(packet.Channel)
continue
}
service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String())
}
} else if oc.GetChannelType(packet.Channel) == "none" {
// Invalid Channel Assignment
oc.CloseChannel(packet.Channel)
} else {
oc.Close()
} }
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" {
// NOTE: These auth checks should be redundant, however they
// are included here for defense-in-depth if for some reason
// a previously authed connection becomes untrusted / not known and
// the state is not cleaned up.
if !oc.Client {
// Clients are not allowed to send contact request responses
service.OnBadUsageError(oc, packet.Channel)
} else if !oc.IsAuthed {
// Can't send a contact request if not authed
service.OnBadUsageError(oc, packet.Channel)
} else {
res := new(Protocol_Data_ContactRequest.Response)
err := proto.Unmarshal(packet.Data[:], res)
log.Printf("%v", res)
if err != nil {
oc.CloseChannel(packet.Channel)
continue
}
service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String())
}
} else if oc.GetChannelType(packet.Channel) == "none" {
// Invalid Channel Assignment
oc.CloseChannel(packet.Channel)
} else {
oc.Close()
} }
} }
} }

View File

@ -1,10 +1,10 @@
package utils package utils
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"net" "io"
"strconv"
) )
// RicochetData is a structure containing the raw data and the channel it the // RicochetData is a structure containing the raw data and the channel it the
@ -14,79 +14,67 @@ type RicochetData struct {
Data []byte Data []byte
} }
func (rd RicochetData) Equals(other RicochetData) bool {
return rd.Channel == other.Channel && bytes.Equal(rd.Data, other.Data)
}
// RicochetNetworkInterface abstract operations that interact with ricochet's // RicochetNetworkInterface abstract operations that interact with ricochet's
// packet layer. // packet layer.
type RicochetNetworkInterface interface { type RicochetNetworkInterface interface {
Recv(conn net.Conn) ([]byte, error) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error
SendRicochetPacket(conn net.Conn, channel int32, data []byte) RecvRicochetPacket(reader io.Reader) (RicochetData, error)
RecvRicochetPackets(conn net.Conn) ([]RicochetData, error)
} }
// RicochetNetwork is a concrete implementation of the RicochetNetworkInterface // RicochetNetwork is a concrete implementation of the RicochetNetworkInterface
type RicochetNetwork struct { type RicochetNetwork struct {
} }
// Recv reads data from the client, and returns the raw byte array, else error.
func (rn *RicochetNetwork) Recv(conn net.Conn) ([]byte, error) {
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
return nil, err
}
ret := make([]byte, n)
copy(ret[:], buf[:])
return ret, nil
}
// SendRicochetPacket places the data into a structure needed for the client to // SendRicochetPacket places the data into a structure needed for the client to
// decode the packet and writes the packet to the network. // decode the packet and writes the packet to the network.
func (rn *RicochetNetwork) SendRicochetPacket(conn net.Conn, channel int32, data []byte) { func (rn *RicochetNetwork) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error {
header := make([]byte, 4+len(data)) packet := make([]byte, 4+len(data))
header[0] = byte(len(header) >> 8) if len(packet) > 65535 {
header[1] = byte(len(header) & 0x00FF) return errors.New("packet too large")
header[2] = 0x00 }
header[3] = byte(channel) binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet)))
copy(header[4:], data[:]) if channel < 0 || channel > 65535 {
conn.Write(header) return errors.New("invalid channel ID")
}
binary.BigEndian.PutUint16(packet[2:4], uint16(channel))
copy(packet[4:], data[:])
for pos := 0; pos < len(packet); {
n, err := dst.Write(packet[pos:])
if err != nil {
return err
}
pos += n
}
return nil
} }
// RecvRicochetPackets returns an array of new messages received from the ricochet client // RecvRicochetPacket returns the next packet from reader as a RicochetData
func (rn *RicochetNetwork) RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) { // structure, or an error.
buf, err := rn.Recv(conn) func (rn *RicochetNetwork) RecvRicochetPacket(reader io.Reader) (RicochetData, error) {
if err != nil && len(buf) < 4 { packet := RicochetData{}
return nil, errors.New("failed to retrieve new messages from the client")
// Read the four-byte header to get packet length
header := make([]byte, 4)
if _, err := io.ReadAtLeast(reader, header, len(header)); err != nil {
return packet, err
} }
pos := 0 size := int(binary.BigEndian.Uint16(header[0:2]))
finished := false if size < 4 {
var datas []RicochetData return packet, errors.New("invalid packet length")
for !finished {
size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2]))
channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4]))
if size < 4 {
return datas, errors.New("invalid ricochet packet received (size=" + strconv.Itoa(size) + ")")
}
if pos+size > len(buf) {
return datas, errors.New("partial data packet received")
}
data := RicochetData{}
data.Channel = int32(channel)
if pos+4 >= len(buf) {
data.Data = make([]byte, 0)
} else {
data.Data = buf[pos+4 : pos+size]
}
datas = append(datas, data)
pos += size
if pos >= len(buf) {
finished = true
}
} }
return datas, nil
packet.Channel = int32(binary.BigEndian.Uint16(header[2:4]))
packet.Data = make([]byte, size-4)
if _, err := io.ReadAtLeast(reader, packet.Data, len(packet.Data)); err != nil {
return packet, err
}
return packet, nil
} }

View File

@ -1,171 +1,105 @@
package utils package utils
import "testing" import (
import "net" "bytes"
import "time" "io"
"testing"
"testing/iotest"
)
type MockConn struct { // Valid packets and their encoded forms
Written []byte var packetTests = []struct {
MockOutput []byte packet RicochetData
encoded []byte
}{
{RicochetData{1, []byte{}}, []byte{0x00, 0x04, 0x00, 0x01}},
{RicochetData{65535, []byte{0xDE, 0xAD, 0xBE, 0xEF}}, []byte{0x00, 0x08, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}},
{RicochetData{2, make([]byte, 65531)}, append([]byte{0xFF, 0xFF, 0x00, 0x02}, make([]byte, 65531)...)},
} }
func (mc *MockConn) Read(b []byte) (int, error) { // Test sending valid packets
copy(b[:], mc.MockOutput[:]) func TestSendRicochetPacket(t *testing.T) {
return len(mc.MockOutput), nil
}
func (mc *MockConn) Write(written []byte) (int, error) {
mc.Written = written
return 0, nil
}
func (mc *MockConn) LocalAddr() net.Addr {
return nil
}
func (mc *MockConn) RemoteAddr() net.Addr {
return nil
}
func (mc *MockConn) Close() error {
return nil
}
func (mc *MockConn) SetDeadline(t time.Time) error {
return nil
}
func (mc *MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (mc *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestSentRicochetPacket(t *testing.T) {
conn := new(MockConn)
rni := RicochetNetwork{} rni := RicochetNetwork{}
rni.SendRicochetPacket(conn, 1, []byte{}) for _, td := range packetTests {
if len(conn.Written) != 4 && conn.Written[0] != 0x00 && conn.Written[1] != 0x00 && conn.Written[2] != 0x01 && conn.Written[3] != 0x00 { var buf bytes.Buffer
t.Errorf("Output of SentRicochetPacket was Unexpected: %x", conn.Written) err := rni.SendRicochetPacket(&buf, td.packet.Channel, td.packet.Data)
if err != nil {
t.Errorf("Error sending packet %v: %v", td.packet, err)
} else if !bytes.Equal(buf.Bytes(), td.encoded) {
t.Errorf("Expected serialized packet %x but got %x", td.encoded, buf.Bytes())
}
} }
} }
func TestRecv(t *testing.T) { // Test sending invalid packets
conn := new(MockConn) func TestSendRicochetPacket_Invalid(t *testing.T) {
conn.MockOutput = []byte{0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{} rni := RicochetNetwork{}
buf, err := rni.Recv(conn) invalidPackets := []RicochetData{
if err != nil || len(buf) != 4 || buf[0] != 0xDE || buf[1] != 0xAD || buf[2] != 0xBE || buf[3] != 0xEF { RicochetData{-1, []byte{}},
t.Errorf("Output of Recv was Unexpected: %x", buf) RicochetData{65536, []byte{}},
RicochetData{0, make([]byte, 65532)},
}
for _, td := range invalidPackets {
var buf bytes.Buffer
err := rni.SendRicochetPacket(&buf, td.Channel, td.Data)
// Expect error
if err == nil {
t.Errorf("Expected error when sending invalid packet %v", td)
}
} }
} }
// Test receiving valid packets
func TestRecvRicochetPacket(t *testing.T) { func TestRecvRicochetPacket(t *testing.T) {
conn := new(MockConn) var buf bytes.Buffer
conn.MockOutput = []byte{00, 0x04, 0x00, 0x01} for _, td := range packetTests {
if _, err := buf.Write(td.encoded); err != nil {
t.Error(err)
return
}
}
// Use a HalfReader to test behavior on short socket reads also
reader := iotest.HalfReader(&buf)
rni := RicochetNetwork{} rni := RicochetNetwork{}
rp, err := rni.RecvRicochetPackets(conn)
if err != nil { for _, td := range packetTests {
t.Errorf("error extracting ricochet packets: %v", err) packet, err := rni.RecvRicochetPacket(reader)
return if err != nil {
} t.Errorf("Error receiving packet %v: %v", td.packet, err)
return
if len(rp) != 1 { } else if !packet.Equals(td.packet) {
t.Errorf("unexpected number of ricochet packets: %d", len(rp)) t.Errorf("Expected unserialized packet %v but got %v", td.packet, packet)
} else {
if rp[0].Channel != 1 {
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
}
if len(rp[0].Data) != 0 {
t.Errorf("expected emptry packet, instead got %x", rp[0].Data)
} }
} }
if packet, err := rni.RecvRicochetPacket(reader); err != io.EOF {
if err != nil {
t.Errorf("Expected EOF on packet stream but received error: %v", err)
} else {
t.Errorf("Expected EOF but received packet: %v", packet)
}
}
} }
func TestRecvRicochetPacketInvalid(t *testing.T) { // Test receiving invalid packets
conn := new(MockConn) func TestRecvRicochetPacket_Invalid(t *testing.T) {
conn.MockOutput = []byte{00, 0x01, 0x00, 0x01}
rni := RicochetNetwork{} rni := RicochetNetwork{}
_, err := rni.RecvRicochetPackets(conn) invalidPackets := [][]byte{
[]byte{0x00, 0x00, 0x00, 0x00},
if err == nil { []byte{0x00, 0x03, 0x00, 0x00},
t.Errorf("recv should have errored due to invalid packets %v", err) []byte{0xff},
[]byte{0x00, 0x06, 0x00, 0x00, 0x00},
[]byte{},
} }
conn.MockOutput = []byte{00, 0x0A, 0x00, 0x01} for _, td := range invalidPackets {
buf := bytes.NewBuffer(td)
_, err = rni.RecvRicochetPackets(conn) packet, err := rni.RecvRicochetPacket(buf)
// Expect error
if err == nil { if err == nil {
t.Errorf("recv should have errored due to invalid packets %v", err) t.Errorf("Expected error when sending invalid packet %x, got packet %v", td, packet)
}
} }
}
func TestRecvRicochetPacketLong(t *testing.T) {
conn := new(MockConn)
conn.MockOutput = []byte{0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{}
rp, err := rni.RecvRicochetPackets(conn)
if err != nil {
t.Errorf("error extracting ricochet packets: %v", err)
return
}
if len(rp) != 1 {
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
} else {
if rp[0].Channel != 255 {
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
}
if len(rp[0].Data) != 4 || rp[0].Data[0] != 0xDE || rp[0].Data[1] != 0xAD || rp[0].Data[2] != 0xBE || rp[0].Data[3] != 0xEF {
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
}
}
}
func TestRecvRicochetPacketMultiplex(t *testing.T) {
conn := new(MockConn)
conn.MockOutput = []byte{0x00, 0x04, 0x00, 0x01, 0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{}
rp, err := rni.RecvRicochetPackets(conn)
if err != nil {
t.Errorf("error extracting ricochet packets: %v", err)
return
}
if len(rp) != 2 {
t.Errorf("unexpected number of ricochet packets, expected 2 gt: %d", len(rp))
} else {
if rp[0].Channel != 1 {
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
}
if len(rp[0].Data) != 0 {
t.Errorf("expected empty packet, instead got %x", rp[0].Data)
}
if rp[1].Channel != 255 {
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
}
if len(rp[1].Data) != 4 || rp[1].Data[0] != 0xDE || rp[1].Data[1] != 0xAD || rp[1].Data[2] != 0xBE || rp[1].Data[3] != 0xEF {
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
}
}
} }