fix off by one error when removing ricochet instances #11
349
ricochet.go
349
ricochet.go
|
@ -98,147 +98,175 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
|
|||
return
|
||||
}
|
||||
|
||||
packets, err := r.rni.RecvRicochetPackets(oc.conn)
|
||||
|
||||
packet, err := r.rni.RecvRicochetPacket(oc.conn)
|
||||
if err != nil {
|
||||
oc.Close()
|
||||
return
|
||||
}
|
||||
|
||||
for _, packet := range packets {
|
||||
if len(packet.Data) == 0 {
|
||||
service.OnChannelClosed(oc, packet.Channel)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(packet.Data) == 0 {
|
||||
service.OnChannelClosed(oc, packet.Channel)
|
||||
if packet.Channel == 0 {
|
||||
|
||||
res := new(Protocol_Data_Control.Packet)
|
||||
err := proto.Unmarshal(packet.Data[:], res)
|
||||
|
||||
if err != nil {
|
||||
service.OnGenericError(oc, packet.Channel)
|
||||
continue
|
||||
}
|
||||
|
||||
if packet.Channel == 0 {
|
||||
if res.GetOpenChannel() != nil {
|
||||
opm := res.GetOpenChannel()
|
||||
|
||||
res := new(Protocol_Data_Control.Packet)
|
||||
err := proto.Unmarshal(packet.Data[:], res)
|
||||
|
||||
if err != nil {
|
||||
service.OnGenericError(oc, packet.Channel)
|
||||
if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" {
|
||||
// Channel is already in use.
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
continue
|
||||
}
|
||||
|
||||
if res.GetOpenChannel() != nil {
|
||||
opm := res.GetOpenChannel()
|
||||
// If I am a Client, the server can only open even numbered channels
|
||||
if oc.Client && opm.GetChannelIdentifier()%2 != 0 {
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
continue
|
||||
}
|
||||
|
||||
if oc.GetChannelType(opm.GetChannelIdentifier()) != "none" {
|
||||
// Channel is already in use.
|
||||
// If I am a Server, the client can only open odd numbered channels
|
||||
if !oc.Client && opm.GetChannelIdentifier()%2 != 1 {
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
continue
|
||||
}
|
||||
|
||||
switch opm.GetChannelType() {
|
||||
case "im.ricochet.auth.hidden-service":
|
||||
if oc.Client {
|
||||
// Servers are authed by default and can't auth with hidden-service
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
continue
|
||||
}
|
||||
|
||||
// If I am a Client, the server can only open even numbered channels
|
||||
if oc.Client && opm.GetChannelIdentifier()%2 != 0 {
|
||||
} else if oc.IsAuthed {
|
||||
// Can't auth if already authed
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
continue
|
||||
}
|
||||
|
||||
// If I am a Server, the client can only open odd numbered channels
|
||||
if !oc.Client && opm.GetChannelIdentifier()%2 != 1 {
|
||||
} else if oc.HasChannel("im.ricochet.auth.hidden-service") {
|
||||
// Can't open more than 1 auth channel
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
continue
|
||||
}
|
||||
|
||||
switch opm.GetChannelType() {
|
||||
case "im.ricochet.auth.hidden-service":
|
||||
if oc.Client {
|
||||
// Servers are authed by default and can't auth with hidden-service
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
} else if oc.IsAuthed {
|
||||
// Can't auth if already authed
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
} else if oc.HasChannel("im.ricochet.auth.hidden-service") {
|
||||
// Can't open more than 1 auth channel
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
} else {
|
||||
clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie)
|
||||
if err == nil {
|
||||
clientCookieB := [16]byte{}
|
||||
copy(clientCookieB[:], clientCookie.([]byte)[:])
|
||||
service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB)
|
||||
} else {
|
||||
clientCookie, err := proto.GetExtension(opm, Protocol_Data_AuthHiddenService.E_ClientCookie)
|
||||
if err == nil {
|
||||
clientCookieB := [16]byte{}
|
||||
copy(clientCookieB[:], clientCookie.([]byte)[:])
|
||||
service.OnAuthenticationRequest(oc, opm.GetChannelIdentifier(), clientCookieB)
|
||||
} else {
|
||||
// Must include Client Cookie
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
// Must include Client Cookie
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
}
|
||||
}
|
||||
case "im.ricochet.chat":
|
||||
if !oc.IsAuthed {
|
||||
// Can't open chat channel if not authorized
|
||||
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
|
||||
} else if !service.IsKnownContact(oc.OtherHostname) {
|
||||
// Can't open chat channel if not a known contact
|
||||
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
|
||||
} else {
|
||||
service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat")
|
||||
}
|
||||
case "im.ricochet.contact.request":
|
||||
if oc.Client {
|
||||
// Servers are not allowed to send contact requests
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
} else if !oc.IsAuthed {
|
||||
// Can't open a contact channel if not authed
|
||||
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
|
||||
} else if oc.HasChannel("im.ricochet.contact.request") {
|
||||
// Only 1 contact channel is allowed to be open at a time
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
} else {
|
||||
contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest)
|
||||
if err == nil {
|
||||
contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest)
|
||||
if check {
|
||||
service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText())
|
||||
break
|
||||
}
|
||||
}
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
}
|
||||
default:
|
||||
service.OnUnknownTypeError(oc, opm.GetChannelIdentifier())
|
||||
}
|
||||
} else if res.GetChannelResult() != nil {
|
||||
crm := res.GetChannelResult()
|
||||
if crm.GetOpened() {
|
||||
switch oc.GetChannelType(crm.GetChannelIdentifier()) {
|
||||
case "im.ricochet.auth.hidden-service":
|
||||
serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie)
|
||||
if err == nil {
|
||||
serverCookieB := [16]byte{}
|
||||
copy(serverCookieB[:], serverCookie.([]byte)[:])
|
||||
service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB)
|
||||
} else {
|
||||
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
|
||||
}
|
||||
case "im.ricochet.chat":
|
||||
if !oc.IsAuthed {
|
||||
// Can't open chat channel if not authorized
|
||||
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
|
||||
} else if !service.IsKnownContact(oc.OtherHostname) {
|
||||
// Can't open chat channel if not a known contact
|
||||
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
|
||||
} else {
|
||||
service.OnOpenChannelRequest(oc, opm.GetChannelIdentifier(), "im.ricochet.chat")
|
||||
}
|
||||
service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier())
|
||||
case "im.ricochet.contact.request":
|
||||
if oc.Client {
|
||||
// Servers are not allowed to send contact requests
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
} else if !oc.IsAuthed {
|
||||
// Can't open a contact channel if not authed
|
||||
service.OnUnauthorizedError(oc, opm.GetChannelIdentifier())
|
||||
} else if oc.HasChannel("im.ricochet.contact.request") {
|
||||
// Only 1 contact channel is allowed to be open at a time
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
} else {
|
||||
contactRequestI, err := proto.GetExtension(opm, Protocol_Data_ContactRequest.E_ContactRequest)
|
||||
if err == nil {
|
||||
contactRequest, check := contactRequestI.(*Protocol_Data_ContactRequest.ContactRequest)
|
||||
if check {
|
||||
service.OnContactRequest(oc, opm.GetChannelIdentifier(), contactRequest.GetNickname(), contactRequest.GetMessageText())
|
||||
break
|
||||
}
|
||||
responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response)
|
||||
if err == nil {
|
||||
response, check := responseI.(*Protocol_Data_ContactRequest.Response)
|
||||
if check {
|
||||
service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String())
|
||||
break
|
||||
}
|
||||
service.OnBadUsageError(oc, opm.GetChannelIdentifier())
|
||||
}
|
||||
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
|
||||
default:
|
||||
service.OnUnknownTypeError(oc, opm.GetChannelIdentifier())
|
||||
}
|
||||
} else if res.GetChannelResult() != nil {
|
||||
crm := res.GetChannelResult()
|
||||
if crm.GetOpened() {
|
||||
switch oc.GetChannelType(crm.GetChannelIdentifier()) {
|
||||
case "im.ricochet.auth.hidden-service":
|
||||
serverCookie, err := proto.GetExtension(crm, Protocol_Data_AuthHiddenService.E_ServerCookie)
|
||||
if err == nil {
|
||||
serverCookieB := [16]byte{}
|
||||
copy(serverCookieB[:], serverCookie.([]byte)[:])
|
||||
service.OnAuthenticationChallenge(oc, crm.GetChannelIdentifier(), serverCookieB)
|
||||
} else {
|
||||
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
|
||||
}
|
||||
case "im.ricochet.chat":
|
||||
service.OnOpenChannelRequestSuccess(oc, crm.GetChannelIdentifier())
|
||||
case "im.ricochet.contact.request":
|
||||
responseI, err := proto.GetExtension(res.GetChannelResult(), Protocol_Data_ContactRequest.E_Response)
|
||||
if err == nil {
|
||||
response, check := responseI.(*Protocol_Data_ContactRequest.Response)
|
||||
if check {
|
||||
service.OnContactRequestAck(oc, crm.GetChannelIdentifier(), response.GetStatus().String())
|
||||
break
|
||||
}
|
||||
}
|
||||
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
|
||||
default:
|
||||
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
|
||||
}
|
||||
} else {
|
||||
if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" {
|
||||
service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String())
|
||||
} else {
|
||||
oc.CloseChannel(crm.GetChannelIdentifier())
|
||||
}
|
||||
service.OnBadUsageError(oc, crm.GetChannelIdentifier())
|
||||
}
|
||||
} else {
|
||||
// Unknown Message
|
||||
oc.CloseChannel(packet.Channel)
|
||||
if oc.GetChannelType(crm.GetChannelIdentifier()) != "none" {
|
||||
service.OnFailedChannelOpen(oc, crm.GetChannelIdentifier(), crm.GetCommonError().String())
|
||||
} else {
|
||||
oc.CloseChannel(crm.GetChannelIdentifier())
|
||||
}
|
||||
}
|
||||
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" {
|
||||
res := new(Protocol_Data_AuthHiddenService.Packet)
|
||||
} else {
|
||||
// Unknown Message
|
||||
oc.CloseChannel(packet.Channel)
|
||||
}
|
||||
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.auth.hidden-service" {
|
||||
res := new(Protocol_Data_AuthHiddenService.Packet)
|
||||
err := proto.Unmarshal(packet.Data[:], res)
|
||||
|
||||
if err != nil {
|
||||
oc.CloseChannel(packet.Channel)
|
||||
continue
|
||||
}
|
||||
|
||||
if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs
|
||||
service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname))
|
||||
} else if res.GetResult() != nil && oc.Client { // Only Servers Send Results
|
||||
service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
|
||||
} else {
|
||||
// If neither of the above are satisfied we just close the connection
|
||||
oc.Close()
|
||||
}
|
||||
|
||||
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" {
|
||||
|
||||
// NOTE: These auth checks should be redundant, however they
|
||||
// are included here for defense-in-depth if for some reason
|
||||
// a previously authed connection becomes untrusted / not known and
|
||||
// the state is not cleaned up.
|
||||
if !oc.IsAuthed {
|
||||
// Can't send chat messages if not authorized
|
||||
service.OnUnauthorizedError(oc, packet.Channel)
|
||||
} else if !service.IsKnownContact(oc.OtherHostname) {
|
||||
// Can't send chat message if not a known contact
|
||||
service.OnUnauthorizedError(oc, packet.Channel)
|
||||
} else {
|
||||
res := new(Protocol_Data_Chat.Packet)
|
||||
err := proto.Unmarshal(packet.Data[:], res)
|
||||
|
||||
if err != nil {
|
||||
|
@ -246,73 +274,42 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
|
|||
continue
|
||||
}
|
||||
|
||||
if res.GetProof() != nil && !oc.Client { // Only Clients Send Proofs
|
||||
service.OnAuthenticationProof(oc, packet.Channel, res.GetProof().GetPublicKey(), res.GetProof().GetSignature(), service.IsKnownContact(oc.OtherHostname))
|
||||
} else if res.GetResult() != nil && oc.Client { // Only Servers Send Results
|
||||
service.OnAuthenticationResult(oc, packet.Channel, res.GetResult().GetAccepted(), res.GetResult().GetIsKnownContact())
|
||||
if res.GetChatMessage() != nil {
|
||||
service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText())
|
||||
} else if res.GetChatAcknowledge() != nil {
|
||||
service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()))
|
||||
} else {
|
||||
// If neither of the above are satisfied we just close the connection
|
||||
oc.Close()
|
||||
}
|
||||
|
||||
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.chat" {
|
||||
|
||||
// NOTE: These auth checks should be redundant, however they
|
||||
// are included here for defense-in-depth if for some reason
|
||||
// a previously authed connection becomes untrusted / not known and
|
||||
// the state is not cleaned up.
|
||||
if !oc.IsAuthed {
|
||||
// Can't send chat messages if not authorized
|
||||
service.OnUnauthorizedError(oc, packet.Channel)
|
||||
} else if !service.IsKnownContact(oc.OtherHostname) {
|
||||
// Can't send chat message if not a known contact
|
||||
service.OnUnauthorizedError(oc, packet.Channel)
|
||||
} else {
|
||||
res := new(Protocol_Data_Chat.Packet)
|
||||
err := proto.Unmarshal(packet.Data[:], res)
|
||||
|
||||
if err != nil {
|
||||
oc.CloseChannel(packet.Channel)
|
||||
continue
|
||||
}
|
||||
|
||||
if res.GetChatMessage() != nil {
|
||||
service.OnChatMessage(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()), res.GetChatMessage().GetMessageText())
|
||||
} else if res.GetChatAcknowledge() != nil {
|
||||
service.OnChatMessageAck(oc, packet.Channel, int32(res.GetChatMessage().GetMessageId()))
|
||||
} else {
|
||||
// If neither of the above are satisfied we just close the connection
|
||||
oc.Close()
|
||||
}
|
||||
}
|
||||
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" {
|
||||
|
||||
// NOTE: These auth checks should be redundant, however they
|
||||
// are included here for defense-in-depth if for some reason
|
||||
// a previously authed connection becomes untrusted / not known and
|
||||
// the state is not cleaned up.
|
||||
if !oc.Client {
|
||||
// Clients are not allowed to send contact request responses
|
||||
service.OnBadUsageError(oc, packet.Channel)
|
||||
} else if !oc.IsAuthed {
|
||||
// Can't send a contact request if not authed
|
||||
service.OnBadUsageError(oc, packet.Channel)
|
||||
} else {
|
||||
res := new(Protocol_Data_ContactRequest.Response)
|
||||
err := proto.Unmarshal(packet.Data[:], res)
|
||||
log.Printf("%v", res)
|
||||
if err != nil {
|
||||
oc.CloseChannel(packet.Channel)
|
||||
continue
|
||||
}
|
||||
service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String())
|
||||
}
|
||||
} else if oc.GetChannelType(packet.Channel) == "none" {
|
||||
// Invalid Channel Assignment
|
||||
oc.CloseChannel(packet.Channel)
|
||||
} else {
|
||||
oc.Close()
|
||||
}
|
||||
} else if oc.GetChannelType(packet.Channel) == "im.ricochet.contact.request" {
|
||||
|
||||
// NOTE: These auth checks should be redundant, however they
|
||||
// are included here for defense-in-depth if for some reason
|
||||
// a previously authed connection becomes untrusted / not known and
|
||||
// the state is not cleaned up.
|
||||
if !oc.Client {
|
||||
// Clients are not allowed to send contact request responses
|
||||
service.OnBadUsageError(oc, packet.Channel)
|
||||
} else if !oc.IsAuthed {
|
||||
// Can't send a contact request if not authed
|
||||
service.OnBadUsageError(oc, packet.Channel)
|
||||
} else {
|
||||
res := new(Protocol_Data_ContactRequest.Response)
|
||||
err := proto.Unmarshal(packet.Data[:], res)
|
||||
log.Printf("%v", res)
|
||||
if err != nil {
|
||||
oc.CloseChannel(packet.Channel)
|
||||
continue
|
||||
}
|
||||
service.OnContactRequestAck(oc, packet.Channel, res.GetStatus().String())
|
||||
}
|
||||
} else if oc.GetChannelType(packet.Channel) == "none" {
|
||||
// Invalid Channel Assignment
|
||||
oc.CloseChannel(packet.Channel)
|
||||
} else {
|
||||
oc.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"io"
|
||||
)
|
||||
|
||||
// RicochetData is a structure containing the raw data and the channel it the
|
||||
|
@ -14,79 +14,67 @@ type RicochetData struct {
|
|||
Data []byte
|
||||
}
|
||||
|
||||
func (rd RicochetData) Equals(other RicochetData) bool {
|
||||
return rd.Channel == other.Channel && bytes.Equal(rd.Data, other.Data)
|
||||
}
|
||||
|
||||
// RicochetNetworkInterface abstract operations that interact with ricochet's
|
||||
// packet layer.
|
||||
type RicochetNetworkInterface interface {
|
||||
Recv(conn net.Conn) ([]byte, error)
|
||||
SendRicochetPacket(conn net.Conn, channel int32, data []byte)
|
||||
RecvRicochetPackets(conn net.Conn) ([]RicochetData, error)
|
||||
SendRicochetPacket(dst io.Writer, channel int32, data []byte) error
|
||||
RecvRicochetPacket(reader io.Reader) (RicochetData, error)
|
||||
}
|
||||
|
||||
// RicochetNetwork is a concrete implementation of the RicochetNetworkInterface
|
||||
type RicochetNetwork struct {
|
||||
}
|
||||
|
||||
// Recv reads data from the client, and returns the raw byte array, else error.
|
||||
func (rn *RicochetNetwork) Recv(conn net.Conn) ([]byte, error) {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make([]byte, n)
|
||||
copy(ret[:], buf[:])
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// SendRicochetPacket places the data into a structure needed for the client to
|
||||
// decode the packet and writes the packet to the network.
|
||||
func (rn *RicochetNetwork) SendRicochetPacket(conn net.Conn, channel int32, data []byte) {
|
||||
header := make([]byte, 4+len(data))
|
||||
header[0] = byte(len(header) >> 8)
|
||||
header[1] = byte(len(header) & 0x00FF)
|
||||
header[2] = 0x00
|
||||
header[3] = byte(channel)
|
||||
copy(header[4:], data[:])
|
||||
conn.Write(header)
|
||||
func (rn *RicochetNetwork) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error {
|
||||
packet := make([]byte, 4+len(data))
|
||||
if len(packet) > 65535 {
|
||||
return errors.New("packet too large")
|
||||
}
|
||||
binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet)))
|
||||
if channel < 0 || channel > 65535 {
|
||||
return errors.New("invalid channel ID")
|
||||
}
|
||||
binary.BigEndian.PutUint16(packet[2:4], uint16(channel))
|
||||
copy(packet[4:], data[:])
|
||||
|
||||
for pos := 0; pos < len(packet); {
|
||||
n, err := dst.Write(packet[pos:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pos += n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecvRicochetPackets returns an array of new messages received from the ricochet client
|
||||
func (rn *RicochetNetwork) RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) {
|
||||
buf, err := rn.Recv(conn)
|
||||
if err != nil && len(buf) < 4 {
|
||||
return nil, errors.New("failed to retrieve new messages from the client")
|
||||
// RecvRicochetPacket returns the next packet from reader as a RicochetData
|
||||
// structure, or an error.
|
||||
func (rn *RicochetNetwork) RecvRicochetPacket(reader io.Reader) (RicochetData, error) {
|
||||
packet := RicochetData{}
|
||||
|
||||
// Read the four-byte header to get packet length
|
||||
header := make([]byte, 4)
|
||||
if _, err := io.ReadAtLeast(reader, header, len(header)); err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
||||
pos := 0
|
||||
finished := false
|
||||
var datas []RicochetData
|
||||
|
||||
for !finished {
|
||||
size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2]))
|
||||
channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4]))
|
||||
|
||||
if size < 4 {
|
||||
return datas, errors.New("invalid ricochet packet received (size=" + strconv.Itoa(size) + ")")
|
||||
}
|
||||
|
||||
if pos+size > len(buf) {
|
||||
return datas, errors.New("partial data packet received")
|
||||
}
|
||||
|
||||
data := RicochetData{}
|
||||
data.Channel = int32(channel)
|
||||
|
||||
if pos+4 >= len(buf) {
|
||||
data.Data = make([]byte, 0)
|
||||
} else {
|
||||
data.Data = buf[pos+4 : pos+size]
|
||||
}
|
||||
|
||||
datas = append(datas, data)
|
||||
pos += size
|
||||
if pos >= len(buf) {
|
||||
finished = true
|
||||
}
|
||||
size := int(binary.BigEndian.Uint16(header[0:2]))
|
||||
if size < 4 {
|
||||
return packet, errors.New("invalid packet length")
|
||||
}
|
||||
return datas, nil
|
||||
|
||||
packet.Channel = int32(binary.BigEndian.Uint16(header[2:4]))
|
||||
packet.Data = make([]byte, size-4)
|
||||
|
||||
if _, err := io.ReadAtLeast(reader, packet.Data, len(packet.Data)); err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
||||
return packet, nil
|
||||
}
|
||||
|
|
|
@ -1,171 +1,105 @@
|
|||
package utils
|
||||
|
||||
import "testing"
|
||||
import "net"
|
||||
import "time"
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
Written []byte
|
||||
MockOutput []byte
|
||||
// Valid packets and their encoded forms
|
||||
var packetTests = []struct {
|
||||
packet RicochetData
|
||||
encoded []byte
|
||||
}{
|
||||
{RicochetData{1, []byte{}}, []byte{0x00, 0x04, 0x00, 0x01}},
|
||||
{RicochetData{65535, []byte{0xDE, 0xAD, 0xBE, 0xEF}}, []byte{0x00, 0x08, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}},
|
||||
{RicochetData{2, make([]byte, 65531)}, append([]byte{0xFF, 0xFF, 0x00, 0x02}, make([]byte, 65531)...)},
|
||||
}
|
||||
|
||||
func (mc *MockConn) Read(b []byte) (int, error) {
|
||||
copy(b[:], mc.MockOutput[:])
|
||||
return len(mc.MockOutput), nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) Write(written []byte) (int, error) {
|
||||
mc.Written = written
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSentRicochetPacket(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
// Test sending valid packets
|
||||
func TestSendRicochetPacket(t *testing.T) {
|
||||
rni := RicochetNetwork{}
|
||||
rni.SendRicochetPacket(conn, 1, []byte{})
|
||||
if len(conn.Written) != 4 && conn.Written[0] != 0x00 && conn.Written[1] != 0x00 && conn.Written[2] != 0x01 && conn.Written[3] != 0x00 {
|
||||
t.Errorf("Output of SentRicochetPacket was Unexpected: %x", conn.Written)
|
||||
for _, td := range packetTests {
|
||||
var buf bytes.Buffer
|
||||
err := rni.SendRicochetPacket(&buf, td.packet.Channel, td.packet.Data)
|
||||
if err != nil {
|
||||
t.Errorf("Error sending packet %v: %v", td.packet, err)
|
||||
} else if !bytes.Equal(buf.Bytes(), td.encoded) {
|
||||
t.Errorf("Expected serialized packet %x but got %x", td.encoded, buf.Bytes())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecv(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
||||
// Test sending invalid packets
|
||||
func TestSendRicochetPacket_Invalid(t *testing.T) {
|
||||
rni := RicochetNetwork{}
|
||||
buf, err := rni.Recv(conn)
|
||||
if err != nil || len(buf) != 4 || buf[0] != 0xDE || buf[1] != 0xAD || buf[2] != 0xBE || buf[3] != 0xEF {
|
||||
t.Errorf("Output of Recv was Unexpected: %x", buf)
|
||||
invalidPackets := []RicochetData{
|
||||
RicochetData{-1, []byte{}},
|
||||
RicochetData{65536, []byte{}},
|
||||
RicochetData{0, make([]byte, 65532)},
|
||||
}
|
||||
|
||||
for _, td := range invalidPackets {
|
||||
var buf bytes.Buffer
|
||||
err := rni.SendRicochetPacket(&buf, td.Channel, td.Data)
|
||||
// Expect error
|
||||
if err == nil {
|
||||
t.Errorf("Expected error when sending invalid packet %v", td)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test receiving valid packets
|
||||
func TestRecvRicochetPacket(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{00, 0x04, 0x00, 0x01}
|
||||
var buf bytes.Buffer
|
||||
for _, td := range packetTests {
|
||||
if _, err := buf.Write(td.encoded); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Use a HalfReader to test behavior on short socket reads also
|
||||
reader := iotest.HalfReader(&buf)
|
||||
rni := RicochetNetwork{}
|
||||
rp, err := rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error extracting ricochet packets: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(rp) != 1 {
|
||||
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
|
||||
} else {
|
||||
if rp[0].Channel != 1 {
|
||||
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
|
||||
}
|
||||
|
||||
if len(rp[0].Data) != 0 {
|
||||
t.Errorf("expected emptry packet, instead got %x", rp[0].Data)
|
||||
for _, td := range packetTests {
|
||||
packet, err := rni.RecvRicochetPacket(reader)
|
||||
if err != nil {
|
||||
t.Errorf("Error receiving packet %v: %v", td.packet, err)
|
||||
return
|
||||
} else if !packet.Equals(td.packet) {
|
||||
t.Errorf("Expected unserialized packet %v but got %v", td.packet, packet)
|
||||
}
|
||||
}
|
||||
|
||||
if packet, err := rni.RecvRicochetPacket(reader); err != io.EOF {
|
||||
if err != nil {
|
||||
t.Errorf("Expected EOF on packet stream but received error: %v", err)
|
||||
} else {
|
||||
t.Errorf("Expected EOF but received packet: %v", packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvRicochetPacketInvalid(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{00, 0x01, 0x00, 0x01}
|
||||
|
||||
// Test receiving invalid packets
|
||||
func TestRecvRicochetPacket_Invalid(t *testing.T) {
|
||||
rni := RicochetNetwork{}
|
||||
_, err := rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("recv should have errored due to invalid packets %v", err)
|
||||
invalidPackets := [][]byte{
|
||||
[]byte{0x00, 0x00, 0x00, 0x00},
|
||||
[]byte{0x00, 0x03, 0x00, 0x00},
|
||||
[]byte{0xff},
|
||||
[]byte{0x00, 0x06, 0x00, 0x00, 0x00},
|
||||
[]byte{},
|
||||
}
|
||||
|
||||
conn.MockOutput = []byte{00, 0x0A, 0x00, 0x01}
|
||||
|
||||
_, err = rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("recv should have errored due to invalid packets %v", err)
|
||||
for _, td := range invalidPackets {
|
||||
buf := bytes.NewBuffer(td)
|
||||
packet, err := rni.RecvRicochetPacket(buf)
|
||||
// Expect error
|
||||
if err == nil {
|
||||
t.Errorf("Expected error when sending invalid packet %x, got packet %v", td, packet)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRecvRicochetPacketLong(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
|
||||
|
||||
rni := RicochetNetwork{}
|
||||
rp, err := rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error extracting ricochet packets: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(rp) != 1 {
|
||||
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
|
||||
} else {
|
||||
if rp[0].Channel != 255 {
|
||||
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
|
||||
}
|
||||
|
||||
if len(rp[0].Data) != 4 || rp[0].Data[0] != 0xDE || rp[0].Data[1] != 0xAD || rp[0].Data[2] != 0xBE || rp[0].Data[3] != 0xEF {
|
||||
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRecvRicochetPacketMultiplex(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{0x00, 0x04, 0x00, 0x01, 0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
|
||||
|
||||
rni := RicochetNetwork{}
|
||||
rp, err := rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error extracting ricochet packets: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(rp) != 2 {
|
||||
t.Errorf("unexpected number of ricochet packets, expected 2 gt: %d", len(rp))
|
||||
} else {
|
||||
|
||||
if rp[0].Channel != 1 {
|
||||
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
|
||||
}
|
||||
|
||||
if len(rp[0].Data) != 0 {
|
||||
t.Errorf("expected empty packet, instead got %x", rp[0].Data)
|
||||
}
|
||||
|
||||
if rp[1].Channel != 255 {
|
||||
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
|
||||
}
|
||||
|
||||
if len(rp[1].Data) != 4 || rp[1].Data[0] != 0xDE || rp[1].Data[1] != 0xAD || rp[1].Data[2] != 0xBE || rp[1].Data[3] != 0xEF {
|
||||
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Reference in New Issue