82 lines
2.1 KiB
Go
82 lines
2.1 KiB
Go
package ws
|
|
|
|
import (
|
|
"net/http/httptest"
|
|
"testing"
|
|
)
|
|
|
|
// Test the remote address parsing logic which is the main testable functionality
|
|
func TestListenerSetRemoteFromReq(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
remoteAddr string
|
|
xForwardedFor string
|
|
expectedRemote string
|
|
}{
|
|
{
|
|
name: "No X-Forwarded-For header",
|
|
remoteAddr: "192.168.1.1:8080",
|
|
xForwardedFor: "",
|
|
expectedRemote: "192.168.1.1:8080",
|
|
},
|
|
{
|
|
name: "Single X-Forwarded-For",
|
|
remoteAddr: "192.168.1.1:8080",
|
|
xForwardedFor: "203.0.113.1",
|
|
expectedRemote: "203.0.113.1",
|
|
},
|
|
{
|
|
name: "Double X-Forwarded-For",
|
|
remoteAddr: "192.168.1.1:8080",
|
|
xForwardedFor: "203.0.113.1 198.51.100.1",
|
|
expectedRemote: "198.51.100.1",
|
|
},
|
|
{
|
|
name: "Empty X-Forwarded-For",
|
|
remoteAddr: "192.168.1.1:8080",
|
|
xForwardedFor: "",
|
|
expectedRemote: "192.168.1.1:8080",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create a test listener to access the setRemoteFromReq method
|
|
listener := &Listener{}
|
|
|
|
req := httptest.NewRequest("GET", "/", nil)
|
|
req.RemoteAddr = tt.remoteAddr
|
|
if tt.xForwardedFor != "" {
|
|
req.Header.Set("X-Forwarded-For", tt.xForwardedFor)
|
|
}
|
|
|
|
// Call the method we want to test
|
|
listener.setRemoteFromReq(req)
|
|
|
|
if listener.RealRemote() != tt.expectedRemote {
|
|
t.Errorf("Expected remote address to be '%s', got '%s'", tt.expectedRemote, listener.RealRemote())
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Test that Req() returns the stored request
|
|
func TestListenerReq(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
listener := &Listener{Request: req}
|
|
|
|
if listener.Req() != req {
|
|
t.Error("Expected Req() to return the stored request")
|
|
}
|
|
}
|
|
|
|
// Test RealRemote method
|
|
func TestListenerRealRemote(t *testing.T) {
|
|
listener := &Listener{}
|
|
listener.remote.Store("test-remote-address")
|
|
|
|
if listener.RealRemote() != "test-remote-address" {
|
|
t.Errorf("Expected RealRemote() to return 'test-remote-address', got '%s'", listener.RealRemote())
|
|
}
|
|
}
|