// Files
// realy/socketapi/socketapi_test.go
// mleku a7944e054c Add unit tests and improve handling in socketapi package
// Introduced comprehensive test coverage for `socketapi` functions, including `handleEvent` and `handleReq` logic. Enhanced `publisher` and `handleEvent` to ensure better filter and event validation, and made minor bug fixes to public key matching logic.
// 2025-06-26 21:20:36 +01:00
//
// 235 lines
// 5.6 KiB
// Go

package socketapi
import (
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/fasthttp/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"realy.lol/context"
"realy.lol/event"
"realy.lol/eventid"
"realy.lol/eventidserial"
"realy.lol/filter"
"realy.lol/filters"
"realy.lol/servemux"
"realy.lol/store"
"realy.lol/ws"
)
// MockServer implements the interfaces.Server interface for testing.
//
// Every method records its call on the embedded mock.Mock and returns the
// values configured with On(...).Return(...). Return values that may
// legitimately be nil (funcs, pointers, slices, interfaces) are nil-checked
// before the type assertion, so an expectation such as Return(true, "", nil)
// yields a nil result instead of a type-assertion panic.
type MockServer struct {
	mock.Mock
}

// AcceptEvent records the call and returns the configured accept flag, notice
// and after-save hook (nil when the expectation returns nil).
func (m *MockServer) AcceptEvent(c context.T, ev *event.T, hr *http.Request, remote string) (accept bool, notice string, afterSave func()) {
	args := m.Called(c, ev, hr, remote)
	if v := args.Get(2); v != nil {
		afterSave = v.(func())
	}
	return args.Bool(0), args.String(1), afterSave
}

// AcceptReq records the call and returns the configured (possibly nil)
// allowed filters plus the ok and modified flags.
func (m *MockServer) AcceptReq(c context.T, hr *http.Request, id []byte, f *filters.T, remote string) (allowed *filters.T, ok bool, modified bool) {
	args := m.Called(c, hr, id, f, remote)
	if v := args.Get(0); v != nil {
		allowed = v.(*filters.T)
	}
	return allowed, args.Bool(1), args.Bool(2)
}

// AddEvent records the call and returns the configured accepted flag and
// message (nil when the expectation returns nil).
func (m *MockServer) AddEvent(c context.T, ev *event.T, hr *http.Request, remote string) (accepted bool, message []byte) {
	args := m.Called(c, ev, hr, remote)
	if v := args.Get(1); v != nil {
		message = v.([]byte)
	}
	return args.Bool(0), message
}

// Context records the call and returns the configured context, or the zero
// value of context.T when the expectation returns nil.
func (m *MockServer) Context() context.T {
	args := m.Called()
	var ctx context.T
	if v := args.Get(0); v != nil {
		ctx = v.(context.T)
	}
	return ctx
}

// HandleRelayInfo records the call; it writes nothing to w.
func (m *MockServer) HandleRelayInfo(w http.ResponseWriter, r *http.Request) {
	m.Called(w, r)
}

// Lock records the call; no real locking takes place.
func (m *MockServer) Lock() {
	m.Called()
}

// ServiceURL records the call and returns the configured URL string.
func (m *MockServer) ServiceURL(req *http.Request) string {
	args := m.Called(req)
	return args.String(0)
}

// Shutdown records the call.
func (m *MockServer) Shutdown() {
	m.Called()
}

// Storage records the call and returns the configured store, or nil when the
// expectation returns nil.
func (m *MockServer) Storage() store.I {
	args := m.Called()
	var s store.I
	if v := args.Get(0); v != nil {
		s = v.(store.I)
	}
	return s
}

// Unlock records the call; no real unlocking takes place.
func (m *MockServer) Unlock() {
	m.Called()
}
// MockStore implements the store.I interface for testing.
//
// Every method records its call on the embedded mock.Mock and returns the
// values configured with On(...).Return(...). Return values that may
// legitimately be nil (slices, channels) are nil-checked before the type
// assertion, so an expectation such as Return(nil, err) yields a nil result
// instead of a type-assertion panic.
type MockStore struct {
	mock.Mock
}

// Init records the call and returns the configured error.
func (m *MockStore) Init(path string) error {
	args := m.Called(path)
	return args.Error(0)
}

// Path records the call and returns the configured path string.
func (m *MockStore) Path() string {
	args := m.Called()
	return args.String(0)
}

// Close records the call and returns the configured error.
func (m *MockStore) Close() error {
	args := m.Called()
	return args.Error(0)
}

// Nuke records the call and returns the configured error; nothing is deleted.
func (m *MockStore) Nuke() error {
	args := m.Called()
	return args.Error(0)
}

// QueryEvents records the call and returns the configured events (nil when
// the expectation returns nil) and error.
func (m *MockStore) QueryEvents(c context.T, f *filter.T) (event.Ts, error) {
	args := m.Called(c, f)
	var evs event.Ts
	if v := args.Get(0); v != nil {
		evs = v.(event.Ts)
	}
	return evs, args.Error(1)
}

// DeleteEvent records the call and returns the configured error. The
// variadic noTombstone flags are recorded as a single []bool argument, so
// expectations must match on that slice (nil when no flags are passed).
func (m *MockStore) DeleteEvent(c context.T, ev *eventid.T, noTombstone ...bool) error {
	args := m.Called(c, ev, noTombstone)
	return args.Error(0)
}

// SaveEvent records the call and returns the configured error.
func (m *MockStore) SaveEvent(c context.T, ev *event.T) error {
	args := m.Called(c, ev)
	return args.Error(0)
}

// Import records the call and returns the configured channel (nil when the
// expectation returns nil). r is not read.
func (m *MockStore) Import(r io.Reader) chan struct{} {
	args := m.Called(r)
	var done chan struct{}
	if v := args.Get(0); v != nil {
		done = v.(chan struct{})
	}
	return done
}

// Export records the call; nothing is written to w. The variadic pubkeys are
// recorded as a single [][]byte argument.
func (m *MockStore) Export(c context.T, w io.Writer, pubkeys ...[]byte) {
	m.Called(c, w, pubkeys)
}

// Sync records the call and returns the configured error.
func (m *MockStore) Sync() error {
	args := m.Called()
	return args.Error(0)
}

// SetLogLevel records the call.
func (m *MockStore) SetLogLevel(level string) {
	m.Called(level)
}

// EventIdsBySerial records the call and returns the configured entries (nil
// when the expectation returns nil) and error.
func (m *MockStore) EventIdsBySerial(start uint64, count int) ([]eventidserial.E, error) {
	args := m.Called(start, count)
	var ids []eventidserial.E
	if v := args.Get(0); v != nil {
		ids = v.([]eventidserial.E)
	}
	return ids, args.Error(1)
}

// EventCount records the call and returns the configured count and error.
// The configured count must be a uint64 (e.g. Return(uint64(3), nil)); a nil
// count yields 0.
func (m *MockStore) EventCount() (uint64, error) {
	args := m.Called()
	var n uint64
	if v := args.Get(0); v != nil {
		n = v.(uint64)
	}
	return n, args.Error(1)
}
// TestNew checks that New can wire the socket API for the "/ws" path onto a
// freshly created serve mux without panicking.
//
// NOTE(review): the assertion below only confirms the mux value is non-nil;
// it does not prove that a handler was registered at "/ws". Strengthen it
// once the servemux API offers a way to inspect or exercise registered routes.
func TestNew(t *testing.T) {
	var (
		srv = new(MockServer)
		mux = servemux.New()
	)
	New(srv, "/ws", mux)
	assert.NotNil(t, mux)
}
// TestGetListener upgrades a real HTTP connection to a websocket and checks
// that GetListener wraps both the server-side connection and the originating
// request.
//
// The assertions run inside the HTTP handler, which is a server goroutine.
// They therefore use assert (t.Errorf is safe from any goroutine) rather than
// require, because t.FailNow must only be called from the test goroutine. The
// test goroutine then waits for the handler to finish. Once the upgrade
// hijacks the connection, httptest.Server.Close no longer waits for the
// handler, so without this wait the test could return before the assertions
// ran, or while they were running.
func TestGetListener(t *testing.T) {
	handled := make(chan struct{})
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Signal completion on every exit path so the test never blocks forever.
		defer close(handled)
		conn, err := Upgrader.Upgrade(w, r, nil)
		if !assert.NoError(t, err) {
			return
		}
		defer conn.Close()
		listener := GetListener(conn, r)
		if assert.NotNil(t, listener) {
			assert.Equal(t, conn, listener.Conn)
			assert.Equal(t, r, listener.Request)
		}
	}))
	defer server.Close()

	// httptest serves plain HTTP; swap the scheme for the websocket dialer.
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	require.NoError(t, err)
	defer conn.Close()

	select {
	case <-handled:
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for the websocket handler to finish")
	}
}
// TestUpgrader pins the package-level websocket Upgrader configuration. It
// expects 1024-byte read and write buffers and an origin check that accepts
// an empty request.
func TestUpgrader(t *testing.T) {
	assert.Equal(t, 1024, Upgrader.ReadBufferSize)
	assert.Equal(t, 1024, Upgrader.WriteBufferSize)
	// Use require, not assert: CheckOrigin is called on the next line. A nil
	// func there would otherwise turn a clear assertion failure into a
	// nil-function-call panic.
	require.NotNil(t, Upgrader.CheckOrigin)
	assert.True(t, Upgrader.CheckOrigin(&http.Request{}))
}
func TestA_Context(t *testing.T) {
mockServer := &MockServer{}
ctx := context.Bg()
mockServer.On("Context").Return(ctx)
a := &A{Server: mockServer}
result := a.Context()
assert.Equal(t, ctx, result)
mockServer.AssertExpectations(t)
}
func TestA_HandleMessage_UnknownEnvelope(t *testing.T) {
mockServer := &MockServer{}
a := &A{Server: mockServer}
// Create a mock listener
mockListener := &ws.Listener{}
a.Listener = mockListener
// Test with invalid message
invalidMsg := []byte(`["INVALID", "test"]`)
// This should not panic and should handle the unknown envelope gracefully
a.HandleMessage(invalidMsg, "127.0.0.1")
}
func TestConstants(t *testing.T) {
// Test that constants are set to reasonable values
assert.Equal(t, 10*time.Second, DefaultWriteWait)
assert.Equal(t, 60*time.Second, DefaultPongWait)
assert.Equal(t, DefaultPongWait/2, DefaultPingWait)
assert.Equal(t, 1000000, DefaultMaxMessageSize) // 1MB (base 10)
}
// TestChallengeConstants pins DefaultChallengeHRP to "nchal" and
// DefaultChallengeLength to 16.
func TestChallengeConstants(t *testing.T) {
	const (
		wantHRP    = "nchal"
		wantLength = 16
	)
	assert.Equal(t, wantHRP, DefaultChallengeHRP)
	assert.Equal(t, wantLength, DefaultChallengeLength)
}