fix off by one error when removing ricochet instances #11
|
@ -98,14 +98,12 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
|
|||
return
|
||||
}
|
||||
|
||||
packets, err := r.rni.RecvRicochetPackets(oc.conn)
|
||||
|
||||
packet, err := r.rni.RecvRicochetPacket(oc.conn)
|
||||
if err != nil {
|
||||
oc.Close()
|
||||
return
|
||||
}
|
||||
|
||||
for _, packet := range packets {
|
||||
|
||||
if len(packet.Data) == 0 {
|
||||
service.OnChannelClosed(oc, packet.Channel)
|
||||
continue
|
||||
|
@ -315,7 +313,6 @@ func (r *Ricochet) processConnection(oc *OpenConnection, service RicochetService
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// negotiateVersion Perform version negotiation with the connected host.
|
||||
func (r *Ricochet) negotiateVersion(conn net.Conn, outbound bool) (*OpenConnection, error) {
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"strconv"
|
||||
"io"
|
||||
)
|
||||
|
||||
// RicochetData is a structure containing the raw data and the channel it the
|
||||
|
@ -14,79 +14,67 @@ type RicochetData struct {
|
|||
Data []byte
|
||||
}
|
||||
|
||||
func (rd RicochetData) Equals(other RicochetData) bool {
|
||||
return rd.Channel == other.Channel && bytes.Equal(rd.Data, other.Data)
|
||||
}
|
||||
|
||||
// RicochetNetworkInterface abstract operations that interact with ricochet's
|
||||
// packet layer.
|
||||
type RicochetNetworkInterface interface {
|
||||
Recv(conn net.Conn) ([]byte, error)
|
||||
SendRicochetPacket(conn net.Conn, channel int32, data []byte)
|
||||
RecvRicochetPackets(conn net.Conn) ([]RicochetData, error)
|
||||
SendRicochetPacket(dst io.Writer, channel int32, data []byte) error
|
||||
RecvRicochetPacket(reader io.Reader) (RicochetData, error)
|
||||
}
|
||||
|
||||
// RicochetNetwork is a concrete implementation of the RicochetNetworkInterface
|
||||
type RicochetNetwork struct {
|
||||
}
|
||||
|
||||
// Recv reads data from the client, and returns the raw byte array, else error.
|
||||
func (rn *RicochetNetwork) Recv(conn net.Conn) ([]byte, error) {
|
||||
buf := make([]byte, 4096)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ret := make([]byte, n)
|
||||
copy(ret[:], buf[:])
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// SendRicochetPacket places the data into a structure needed for the client to
|
||||
// decode the packet and writes the packet to the network.
|
||||
func (rn *RicochetNetwork) SendRicochetPacket(conn net.Conn, channel int32, data []byte) {
|
||||
header := make([]byte, 4+len(data))
|
||||
header[0] = byte(len(header) >> 8)
|
||||
header[1] = byte(len(header) & 0x00FF)
|
||||
header[2] = 0x00
|
||||
header[3] = byte(channel)
|
||||
copy(header[4:], data[:])
|
||||
conn.Write(header)
|
||||
func (rn *RicochetNetwork) SendRicochetPacket(dst io.Writer, channel int32, data []byte) error {
|
||||
packet := make([]byte, 4+len(data))
|
||||
if len(packet) > 65535 {
|
||||
return errors.New("packet too large")
|
||||
}
|
||||
binary.BigEndian.PutUint16(packet[0:2], uint16(len(packet)))
|
||||
if channel < 0 || channel > 65535 {
|
||||
return errors.New("invalid channel ID")
|
||||
}
|
||||
binary.BigEndian.PutUint16(packet[2:4], uint16(channel))
|
||||
copy(packet[4:], data[:])
|
||||
|
||||
for pos := 0; pos < len(packet); {
|
||||
n, err := dst.Write(packet[pos:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pos += n
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecvRicochetPackets returns an array of new messages received from the ricochet client
|
||||
func (rn *RicochetNetwork) RecvRicochetPackets(conn net.Conn) ([]RicochetData, error) {
|
||||
buf, err := rn.Recv(conn)
|
||||
if err != nil && len(buf) < 4 {
|
||||
return nil, errors.New("failed to retrieve new messages from the client")
|
||||
// RecvRicochetPacket returns the next packet from reader as a RicochetData
|
||||
// structure, or an error.
|
||||
func (rn *RicochetNetwork) RecvRicochetPacket(reader io.Reader) (RicochetData, error) {
|
||||
packet := RicochetData{}
|
||||
|
||||
// Read the four-byte header to get packet length
|
||||
header := make([]byte, 4)
|
||||
if _, err := io.ReadAtLeast(reader, header, len(header)); err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
||||
pos := 0
|
||||
finished := false
|
||||
var datas []RicochetData
|
||||
|
||||
for !finished {
|
||||
size := int(binary.BigEndian.Uint16(buf[pos+0 : pos+2]))
|
||||
channel := int(binary.BigEndian.Uint16(buf[pos+2 : pos+4]))
|
||||
|
||||
size := int(binary.BigEndian.Uint16(header[0:2]))
|
||||
if size < 4 {
|
||||
return datas, errors.New("invalid ricochet packet received (size=" + strconv.Itoa(size) + ")")
|
||||
return packet, errors.New("invalid packet length")
|
||||
}
|
||||
|
||||
if pos+size > len(buf) {
|
||||
return datas, errors.New("partial data packet received")
|
||||
packet.Channel = int32(binary.BigEndian.Uint16(header[2:4]))
|
||||
packet.Data = make([]byte, size-4)
|
||||
|
||||
if _, err := io.ReadAtLeast(reader, packet.Data, len(packet.Data)); err != nil {
|
||||
return packet, err
|
||||
}
|
||||
|
||||
data := RicochetData{}
|
||||
data.Channel = int32(channel)
|
||||
|
||||
if pos+4 >= len(buf) {
|
||||
data.Data = make([]byte, 0)
|
||||
} else {
|
||||
data.Data = buf[pos+4 : pos+size]
|
||||
}
|
||||
|
||||
datas = append(datas, data)
|
||||
pos += size
|
||||
if pos >= len(buf) {
|
||||
finished = true
|
||||
}
|
||||
}
|
||||
return datas, nil
|
||||
return packet, nil
|
||||
}
|
||||
|
|
|
@ -1,171 +1,105 @@
|
|||
package utils
|
||||
|
||||
import "testing"
|
||||
import "net"
|
||||
import "time"
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
"testing/iotest"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
Written []byte
|
||||
MockOutput []byte
|
||||
// Valid packets and their encoded forms
|
||||
var packetTests = []struct {
|
||||
packet RicochetData
|
||||
encoded []byte
|
||||
}{
|
||||
{RicochetData{1, []byte{}}, []byte{0x00, 0x04, 0x00, 0x01}},
|
||||
{RicochetData{65535, []byte{0xDE, 0xAD, 0xBE, 0xEF}}, []byte{0x00, 0x08, 0xFF, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}},
|
||||
{RicochetData{2, make([]byte, 65531)}, append([]byte{0xFF, 0xFF, 0x00, 0x02}, make([]byte, 65531)...)},
|
||||
}
|
||||
|
||||
func (mc *MockConn) Read(b []byte) (int, error) {
|
||||
copy(b[:], mc.MockOutput[:])
|
||||
return len(mc.MockOutput), nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) Write(written []byte) (int, error) {
|
||||
mc.Written = written
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) LocalAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) RemoteAddr() net.Addr {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *MockConn) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSentRicochetPacket(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
// Test sending valid packets
|
||||
func TestSendRicochetPacket(t *testing.T) {
|
||||
rni := RicochetNetwork{}
|
||||
rni.SendRicochetPacket(conn, 1, []byte{})
|
||||
if len(conn.Written) != 4 && conn.Written[0] != 0x00 && conn.Written[1] != 0x00 && conn.Written[2] != 0x01 && conn.Written[3] != 0x00 {
|
||||
t.Errorf("Output of SentRicochetPacket was Unexpected: %x", conn.Written)
|
||||
for _, td := range packetTests {
|
||||
var buf bytes.Buffer
|
||||
err := rni.SendRicochetPacket(&buf, td.packet.Channel, td.packet.Data)
|
||||
if err != nil {
|
||||
t.Errorf("Error sending packet %v: %v", td.packet, err)
|
||||
} else if !bytes.Equal(buf.Bytes(), td.encoded) {
|
||||
t.Errorf("Expected serialized packet %x but got %x", td.encoded, buf.Bytes())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecv(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
||||
// Test sending invalid packets
|
||||
func TestSendRicochetPacket_Invalid(t *testing.T) {
|
||||
rni := RicochetNetwork{}
|
||||
buf, err := rni.Recv(conn)
|
||||
if err != nil || len(buf) != 4 || buf[0] != 0xDE || buf[1] != 0xAD || buf[2] != 0xBE || buf[3] != 0xEF {
|
||||
t.Errorf("Output of Recv was Unexpected: %x", buf)
|
||||
invalidPackets := []RicochetData{
|
||||
RicochetData{-1, []byte{}},
|
||||
RicochetData{65536, []byte{}},
|
||||
RicochetData{0, make([]byte, 65532)},
|
||||
}
|
||||
|
||||
for _, td := range invalidPackets {
|
||||
var buf bytes.Buffer
|
||||
err := rni.SendRicochetPacket(&buf, td.Channel, td.Data)
|
||||
// Expect error
|
||||
if err == nil {
|
||||
t.Errorf("Expected error when sending invalid packet %v", td)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test receiving valid packets
|
||||
func TestRecvRicochetPacket(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{00, 0x04, 0x00, 0x01}
|
||||
|
||||
rni := RicochetNetwork{}
|
||||
rp, err := rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error extracting ricochet packets: %v", err)
|
||||
var buf bytes.Buffer
|
||||
for _, td := range packetTests {
|
||||
if _, err := buf.Write(td.encoded); err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(rp) != 1 {
|
||||
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
|
||||
} else {
|
||||
if rp[0].Channel != 1 {
|
||||
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
|
||||
}
|
||||
|
||||
if len(rp[0].Data) != 0 {
|
||||
t.Errorf("expected emptry packet, instead got %x", rp[0].Data)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRecvRicochetPacketInvalid(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{00, 0x01, 0x00, 0x01}
|
||||
|
||||
// Use a HalfReader to test behavior on short socket reads also
|
||||
reader := iotest.HalfReader(&buf)
|
||||
rni := RicochetNetwork{}
|
||||
_, err := rni.RecvRicochetPackets(conn)
|
||||
|
||||
for _, td := range packetTests {
|
||||
packet, err := rni.RecvRicochetPacket(reader)
|
||||
if err != nil {
|
||||
t.Errorf("Error receiving packet %v: %v", td.packet, err)
|
||||
return
|
||||
} else if !packet.Equals(td.packet) {
|
||||
t.Errorf("Expected unserialized packet %v but got %v", td.packet, packet)
|
||||
}
|
||||
}
|
||||
|
||||
if packet, err := rni.RecvRicochetPacket(reader); err != io.EOF {
|
||||
if err != nil {
|
||||
t.Errorf("Expected EOF on packet stream but received error: %v", err)
|
||||
} else {
|
||||
t.Errorf("Expected EOF but received packet: %v", packet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test receiving invalid packets
|
||||
func TestRecvRicochetPacket_Invalid(t *testing.T) {
|
||||
rni := RicochetNetwork{}
|
||||
invalidPackets := [][]byte{
|
||||
[]byte{0x00, 0x00, 0x00, 0x00},
|
||||
[]byte{0x00, 0x03, 0x00, 0x00},
|
||||
[]byte{0xff},
|
||||
[]byte{0x00, 0x06, 0x00, 0x00, 0x00},
|
||||
[]byte{},
|
||||
}
|
||||
|
||||
for _, td := range invalidPackets {
|
||||
buf := bytes.NewBuffer(td)
|
||||
packet, err := rni.RecvRicochetPacket(buf)
|
||||
// Expect error
|
||||
if err == nil {
|
||||
t.Errorf("recv should have errored due to invalid packets %v", err)
|
||||
}
|
||||
|
||||
conn.MockOutput = []byte{00, 0x0A, 0x00, 0x01}
|
||||
|
||||
_, err = rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("recv should have errored due to invalid packets %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRecvRicochetPacketLong(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
|
||||
|
||||
rni := RicochetNetwork{}
|
||||
rp, err := rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error extracting ricochet packets: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(rp) != 1 {
|
||||
t.Errorf("unexpected number of ricochet packets: %d", len(rp))
|
||||
} else {
|
||||
if rp[0].Channel != 255 {
|
||||
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
|
||||
}
|
||||
|
||||
if len(rp[0].Data) != 4 || rp[0].Data[0] != 0xDE || rp[0].Data[1] != 0xAD || rp[0].Data[2] != 0xBE || rp[0].Data[3] != 0xEF {
|
||||
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
|
||||
t.Errorf("Expected error when sending invalid packet %x, got packet %v", td, packet)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRecvRicochetPacketMultiplex(t *testing.T) {
|
||||
conn := new(MockConn)
|
||||
conn.MockOutput = []byte{0x00, 0x04, 0x00, 0x01, 0x00, 0x08, 0x00, 0xFF, 0xDE, 0xAD, 0xBE, 0xEF}
|
||||
|
||||
rni := RicochetNetwork{}
|
||||
rp, err := rni.RecvRicochetPackets(conn)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error extracting ricochet packets: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(rp) != 2 {
|
||||
t.Errorf("unexpected number of ricochet packets, expected 2 gt: %d", len(rp))
|
||||
} else {
|
||||
|
||||
if rp[0].Channel != 1 {
|
||||
t.Errorf("channel number is Unexpected expected 1: %d", rp[0].Channel)
|
||||
}
|
||||
|
||||
if len(rp[0].Data) != 0 {
|
||||
t.Errorf("expected empty packet, instead got %x", rp[0].Data)
|
||||
}
|
||||
|
||||
if rp[1].Channel != 255 {
|
||||
t.Errorf("channel number is Unexpected expected 255 got: %d", rp[0].Channel)
|
||||
}
|
||||
|
||||
if len(rp[1].Data) != 4 || rp[1].Data[0] != 0xDE || rp[1].Data[1] != 0xAD || rp[1].Data[2] != 0xBE || rp[1].Data[3] != 0xEF {
|
||||
t.Errorf("expected 0xDEADBEEF packet, instead got %x", rp[0].Data)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Reference in New Issue