fix off by one error when removing ricochet instances #11

Closed
dan wants to merge 1 commits from dan/libricochet-go:master into master
3 changed files with 300 additions and 381 deletions

View File

@ -98,14 +98,12 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
return return
} }
packets, err := r.rni.RecvRicochetPackets(oc.conn) packet, err := r.rni.RecvRicochetPacket(oc.conn)
if err != nil { if err != nil {
oc.Close()
return return
} }
for _, packet := range packets {
if len(packet.Data) == 0 { if len(packet.Data) == 0 {
service.OnChannelClosed(oc, packet.Channel) service.OnChannelClosed(oc, packet.Channel)
continue continue
@ -315,7 +313,6 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
} }
} }
} }
}
// negotiateVersion Perform version negotiation with the connected host. // negotiateVersion Perform version negotiation with the connected host.
func (r *Ricochet) negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) { func (r *Ricochet) negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) {

View File

@ -1,10 +1,10 @@
package utils package utils
import ( import (
"bytes"
"encoding/binary" "encoding/binary"
"errors" "errors"
"net" "io"
"strconv"
) )
// RicochetData is a structure containing the raw data and the channel it the // RicochetData is a structure containing the raw data and the channel it the
@ -14,79 +14,67 @@ type RicochetData struct {
Data []byte Data []byte
} }
func (rd RicochetData) Equals(other RicochetData) bool {
return rd.Channel == other.Channel && bytes.Equal(rd.Data, other.Data)
}
// RicochetNetworkInterface abstract operations that interact with ricochet's // RicochetNetworkInterface abstract operations that interact with ricochet's
// packet layer. // packet layer.
type RicochetNetworkInterface interface { type RicochetNetworkInterface interface {
Recv(conn net.Conn) ([]byte, error) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error
SendRicochetPacket(conn net.Conn, channel int32, data []byte) RecvRicochetPacket(reader io.Reader) (RicochetData, error)
RecvRicochetPackets(conn net.Conn) ([]RicochetData, error)
} }
// RicochetNetwork is a concrete implementation of the RicochetNetworkInterface // RicochetNetwork is a concrete implementation of the RicochetNetworkInterface
type RicochetNetwork struct { type RicochetNetwork struct {
} }
// Recv reads data from the client, and returns the raw byte array, else error.
func (rn *RicochetNetwork) Recv(conn net.Conn) ([]byte, error) {
buf := make([]byte, 4096)
n, err := conn.Read(buf)
if err != nil {
return nil, err
}
ret := make([]byte, n)
copy(ret[:], buf[:])
return ret, nil
}
// SendRicochetPacket places the data into a structure needed for the client to // SendRicochetPacket places the data into a structure needed for the client to
// decode the packet and writes the packet to the network. // decode the packet and writes the packet to the network.
func (rn *RicochetNetwork) SendRicochetPacket(conn net.Conn, channel int32, data []byte) { func (rn *RicochetNetwork) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error {
header := make([]byte, 4+len(data)) packet := make([]byte, 4+len(data))
header[0] = byte(len(header) >> 8) if len(packet) > 65535 {
header[1] = byte(len(header) & 0x00FF) return errors.New("packet too large")
header[2] = 0x00 }
header[3] = byte(channel) binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet)))
copy(header[4:], data[:]) if channel < 0 || channel > 65535 {
conn.Write(header) return errors.New("invalid channel ID")
}
binary.BigEndian.PutUint16(packet[2:4], uint16(channel))
copy(packet[4:], data[:])
for pos := 0; pos < len(packet); {
n, err := dst.Write(packet[pos:])
if err != nil {
return err
}
pos += n
}
return nil
} }
// RecvRicochetPackets returns an array of new messages received from the ricochet client // RecvRicochetPacket returns the next packet from reader as a RicochetData
func (rn *RicochetNetwork) RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) { // structure, or an error.
buf, err := rn.Recv(conn) func (rn *RicochetNetwork) RecvRicochetPacket(reader io.Reader) (RicochetData, error) {
if err != nil && len(buf) < 4 { packet := RicochetData{}
return nil, errors.New("failed to retrieve new messages from the client")
// Read the four-byte header to get packet length
header := make([]byte, 4)
if _, err := io.ReadAtLeast(reader, header, len(header)); err != nil {
return packet, err
} }
pos := 0 size := int(binary.BigEndian.Uint16(header[0:2]))
finished := false
var datas []RicochetData
for !finished {
size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2]))
channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4]))
if size < 4 { if size < 4 {
return datas, errors.New("invalid ricochet packet received (size=" + strconv.Itoa(size) + ")") return packet, errors.New("invalid packet length")
} }
if pos+size > len(buf) { packet.Channel = int32(binary.BigEndian.Uint16(header[2:4]))
return datas, errors.New("partial data packet received") packet.Data = make([]byte, size-4)
if _, err := io.ReadAtLeast(reader, packet.Data, len(packet.Data)); err != nil {
return packet, err
} }
data := RicochetData{} return packet, nil
data.Channel = int32(channel)
if pos+4 >= len(buf) {
data.Data = make([]byte, 0)
} else {
data.Data = buf[pos+4 : pos+size]
}
datas = append(datas, data)
pos += size
if pos >= len(buf) {
finished = true
}
}
return datas, nil
} }

View File

@ -1,171 +1,105 @@
package utils package utils
import "testing" import (
import "net" "bytes"
import "time" "io"
"testing"
"testing/iotest"
)
type MockConn struct { // Valid packets and their encoded forms
Written []byte var packetTests = []struct {
MockOutput []byte packet RicochetData
encoded []byte
}{
{RicochetData{1, []byte{}}, []byte{0x00, 0x04, 0x00, 0x01}},
{RicochetData{65535, []byte{0xDE, 0xAD, 0xBE, 0xEF}}, []byte{0x00, 0x08, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}},
{RicochetData{2, make([]byte, 65531)}, append([]byte{0xFF, 0xFF, 0x00, 0x02}, make([]byte, 65531)...)},
} }
func (mc *MockConn) Read(b []byte) (int, error) { // Test sending valid packets
copy(b[:], mc.MockOutput[:]) func TestSendRicochetPacket(t *testing.T) {
return len(mc.MockOutput), nil
}
func (mc *MockConn) Write(written []byte) (int, error) {
mc.Written = written
return 0, nil
}
func (mc *MockConn) LocalAddr() net.Addr {
return nil
}
func (mc *MockConn) RemoteAddr() net.Addr {
return nil
}
func (mc *MockConn) Close() error {
return nil
}
func (mc *MockConn) SetDeadline(t time.Time) error {
return nil
}
func (mc *MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (mc *MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestSentRicochetPacket(t *testing.T) {
conn := new(MockConn)
rni := RicochetNetwork{} rni := RicochetNetwork{}
rni.SendRicochetPacket(conn, 1, []byte{}) for _, td := range packetTests {
if len(conn.Written) != 4 && conn.Written[0] != 0x00 && conn.Written[1] != 0x00 && conn.Written[2] != 0x01 && conn.Written[3] != 0x00 { var buf bytes.Buffer
t.Errorf("Output of SentRicochetPacket was Unexpected: %x", conn.Written) err := rni.SendRicochetPacket(&buf, td.packet.Channel, td.packet.Data)
if err != nil {
t.Errorf("Error sending packet %v: %v", td.packet, err)
} else if !bytes.Equal(buf.Bytes(), td.encoded) {
t.Errorf("Expected serialized packet %x but got %x", td.encoded, buf.Bytes())
}
} }
} }
func TestRecv(t *testing.T) { // Test sending invalid packets
conn := new(MockConn) func TestSendRicochetPacket_Invalid(t *testing.T) {
conn.MockOutput = []byte{0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{} rni := RicochetNetwork{}
buf, err := rni.Recv(conn) invalidPackets := []RicochetData{
if err != nil || len(buf) != 4 || buf[0] != 0xDE || buf[1] != 0xAD || buf[2] != 0xBE || buf[3] != 0xEF { RicochetData{-1, []byte{}},
t.Errorf("Output of Recv was Unexpected: %x", buf) RicochetData{65536, []byte{}},
RicochetData{0, make([]byte, 65532)},
}
for _, td := range invalidPackets {
var buf bytes.Buffer
err := rni.SendRicochetPacket(&buf, td.Channel, td.Data)
// Expect error
if err == nil {
t.Errorf("Expected error when sending invalid packet %v", td)
}
} }
} }
// Test receiving valid packets
func TestRecvRicochetPacket(t *testing.T) { func TestRecvRicochetPacket(t *testing.T) {
conn := new(MockConn) var buf bytes.Buffer
conn.MockOutput = []byte{00, 0x04, 0x00, 0x01} for _, td := range packetTests {
if _, err := buf.Write(td.encoded); err != nil {
rni := RicochetNetwork{} t.Error(err)
rp, err := rni.RecvRicochetPackets(conn)
if err != nil {
t.Errorf("error extracting ricochet packets: %v", err)
return return
} }
if len(rp) != 1 {
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
} else {
if rp[0].Channel != 1 {
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
} }
if len(rp[0].Data) != 0 { // Use a HalfReader to test behavior on short socket reads also
t.Errorf("expected emptry packet, instead got %x", rp[0].Data) reader := iotest.HalfReader(&buf)
}
}
}
func TestRecvRicochetPacketInvalid(t *testing.T) {
conn := new(MockConn)
conn.MockOutput = []byte{00, 0x01, 0x00, 0x01}
rni := RicochetNetwork{} rni := RicochetNetwork{}
_, err := rni.RecvRicochetPackets(conn)
for _, td := range packetTests {
packet, err := rni.RecvRicochetPacket(reader)
if err != nil {
t.Errorf("Error receiving packet %v: %v", td.packet, err)
return
} else if !packet.Equals(td.packet) {
t.Errorf("Expected unserialized packet %v but got %v", td.packet, packet)
}
}
if packet, err := rni.RecvRicochetPacket(reader); err != io.EOF {
if err != nil {
t.Errorf("Expected EOF on packet stream but received error: %v", err)
} else {
t.Errorf("Expected EOF but received packet: %v", packet)
}
}
}
// Test receiving invalid packets
func TestRecvRicochetPacket_Invalid(t *testing.T) {
rni := RicochetNetwork{}
invalidPackets := [][]byte{
[]byte{0x00, 0x00, 0x00, 0x00},
[]byte{0x00, 0x03, 0x00, 0x00},
[]byte{0xff},
[]byte{0x00, 0x06, 0x00, 0x00, 0x00},
[]byte{},
}
for _, td := range invalidPackets {
buf := bytes.NewBuffer(td)
packet, err := rni.RecvRicochetPacket(buf)
// Expect error
if err == nil { if err == nil {
t.Errorf("recv should have errored due to invalid packets %v", err) t.Errorf("Expected error when sending invalid packet %x, got packet %v", td, packet)
}
conn.MockOutput = []byte{00, 0x0A, 0x00, 0x01}
_, err = rni.RecvRicochetPackets(conn)
if err == nil {
t.Errorf("recv should have errored due to invalid packets %v", err)
}
}
func TestRecvRicochetPacketLong(t *testing.T) {
conn := new(MockConn)
conn.MockOutput = []byte{0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{}
rp, err := rni.RecvRicochetPackets(conn)
if err != nil {
t.Errorf("error extracting ricochet packets: %v", err)
return
}
if len(rp) != 1 {
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
} else {
if rp[0].Channel != 255 {
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
}
if len(rp[0].Data) != 4 || rp[0].Data[0] != 0xDE || rp[0].Data[1] != 0xAD || rp[0].Data[2] != 0xBE || rp[0].Data[3] != 0xEF {
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
} }
} }
}
func TestRecvRicochetPacketMultiplex(t *testing.T) {
conn := new(MockConn)
conn.MockOutput = []byte{0x00, 0x04, 0x00, 0x01, 0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
rni := RicochetNetwork{}
rp, err := rni.RecvRicochetPackets(conn)
if err != nil {
t.Errorf("error extracting ricochet packets: %v", err)
return
}
if len(rp) != 2 {
t.Errorf("unexpected number of ricochet packets, expected 2 gt: %d", len(rp))
} else {
if rp[0].Channel != 1 {
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
}
if len(rp[0].Data) != 0 {
t.Errorf("expected empty packet, instead got %x", rp[0].Data)
}
if rp[1].Channel != 255 {
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
}
if len(rp[1].Data) != 4 || rp[1].Data[0] != 0xDE || rp[1].Data[1] != 0xAD || rp[1].Data[2] != 0xBE || rp[1].Data[3] != 0xEF {
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
}
}
} }