diff --git a/go-socks-lb/.gitignore b/go-socks-lb/.gitignore new file mode 100644 index 0000000..0026861 --- /dev/null +++ b/go-socks-lb/.gitignore @@ -0,0 +1,22 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe diff --git a/go-socks-lb/LICENSE b/go-socks-lb/LICENSE new file mode 100644 index 0000000..a5df10e --- /dev/null +++ b/go-socks-lb/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2014 Armon Dadgar + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/go-socks-lb/README.md b/go-socks-lb/README.md new file mode 100644 index 0000000..9cd1563 --- /dev/null +++ b/go-socks-lb/README.md @@ -0,0 +1,45 @@ +go-socks5 [![Build Status](https://travis-ci.org/armon/go-socks5.png)](https://travis-ci.org/armon/go-socks5) +========= + +Provides the `socks5` package that implements a [SOCKS5 server](http://en.wikipedia.org/wiki/SOCKS). +SOCKS (Secure Sockets) is used to route traffic between a client and server through +an intermediate proxy layer. This can be used to bypass firewalls or NATs. + +Feature +======= + +The package has the following features: +* "No Auth" mode +* User/Password authentication +* Support for the CONNECT command +* Rules to do granular filtering of commands +* Custom DNS resolution +* Unit tests + +TODO +==== + +The package still needs the following: +* Support for the BIND command +* Support for the ASSOCIATE command + + +Example +======= + +Below is a simple example of usage + +```go +// Create a SOCKS5 server +conf := &socks5.Config{} +server, err := socks5.New(conf) +if err != nil { + panic(err) +} + +// Create SOCKS5 proxy on localhost port 8000 +if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil { + panic(err) +} +``` + diff --git a/go-socks-lb/auth.go b/go-socks-lb/auth.go new file mode 100644 index 0000000..7582d79 --- /dev/null +++ b/go-socks-lb/auth.go @@ -0,0 +1,151 @@ +package main + +import ( + "fmt" + "io" +) + +const ( + NoAuth = uint8(0) + noAcceptable = uint8(255) + UserPassAuth = uint8(2) + userAuthVersion = uint8(1) + authSuccess = uint8(0) + authFailure = uint8(1) +) + +var ( + UserAuthFailed = fmt.Errorf("User authentication failed") + NoSupportedAuth = fmt.Errorf("No supported authentication mechanism") +) + +// A Request encapsulates authentication state provided +// during negotiation +type AuthContext struct { + // Provided auth method + Method uint8 + // Payload provided during negotiation. + // Keys depend on the used auth method. + // For UserPassauth contains Username + Payload map[string]string +} + +type Authenticator interface { + Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) + GetCode() uint8 +} + +// NoAuthAuthenticator is used to handle the "No Authentication" mode +type NoAuthAuthenticator struct{} + +func (a NoAuthAuthenticator) GetCode() uint8 { + return NoAuth +} + +func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { + _, err := writer.Write([]byte{socks5Version, NoAuth}) + return &AuthContext{NoAuth, nil}, err +} + +// UserPassAuthenticator is used to handle username/password based +// authentication +type UserPassAuthenticator struct { + Credentials CredentialStore +} + +func (a UserPassAuthenticator) GetCode() uint8 { + return UserPassAuth +} + +func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { + // Tell the client to use user/pass auth + if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil { + return nil, err + } + + // Get the version and username length + header := []byte{0, 0} + if _, err := io.ReadAtLeast(reader, header, 2); err != nil { + return nil, err + } + + // Ensure we are compatible + if header[0] != userAuthVersion { + return nil, fmt.Errorf("Unsupported auth version: %v", header[0]) + } + + // Get the user name + userLen := int(header[1]) + user := make([]byte, userLen) + if _, err := io.ReadAtLeast(reader, user, userLen); err != nil { + return nil, err + } + + // Get the password length + if _, err := reader.Read(header[:1]); err != nil { + return nil, err + } + + // Get the password + passLen := int(header[0]) + pass := make([]byte, passLen) + if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil { + return nil, err + } + + // Verify the password + if a.Credentials.Valid(string(user), string(pass)) { + if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil { + return nil, err + } + } else { + if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil { + return nil, err + } + return nil, UserAuthFailed + } + + // Done + return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil +} + +// authenticate is used to handle connection authentication +func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) { + // Get the methods + methods, err := readMethods(bufConn) + if err != nil { + return nil, fmt.Errorf("Failed to get auth methods: %v", err) + } + + // Select a usable method + for _, method := range methods { + cator, found := s.authMethods[method] + if found { + return cator.Authenticate(bufConn, conn) + } + } + + // No usable method found + return nil, noAcceptableAuth(conn) +} + +// noAcceptableAuth is used to handle when we have no eligible +// authentication mechanism +func noAcceptableAuth(conn io.Writer) error { + conn.Write([]byte{socks5Version, noAcceptable}) + return NoSupportedAuth +} + +// readMethods is used to read the number of methods +// and proceeding auth methods +func readMethods(r io.Reader) ([]byte, error) { + header := []byte{0} + if _, err := r.Read(header); err != nil { + return nil, err + } + + numMethods := int(header[0]) + methods := make([]byte, numMethods) + _, err := io.ReadAtLeast(r, methods, numMethods) + return methods, err +} diff --git a/go-socks-lb/auth_test.go b/go-socks-lb/auth_test.go new file mode 100644 index 0000000..c8afd9f --- /dev/null +++ b/go-socks-lb/auth_test.go @@ -0,0 +1,119 @@ +package main + +import ( + "bytes" + "testing" +) + +func TestNoAuth(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{1, NoAuth}) + var resp bytes.Buffer + + s, _ := New(&Config{}) + ctx, err := s.authenticate(&resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if ctx.Method != NoAuth { + t.Fatal("Invalid Context Method") + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, NoAuth}) { + t.Fatalf("bad: %v", out) + } +} + +func TestPasswordAuth_Valid(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{2, NoAuth, UserPassAuth}) + req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + + cator := UserPassAuthenticator{Credentials: cred} + + s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) + + ctx, err := s.authenticate(&resp, req) + if err != nil { + t.Fatalf("err: %v", err) + } + + if ctx.Method != UserPassAuth { + t.Fatal("Invalid Context Method") + } + + val, ok := ctx.Payload["Username"] + if !ok { + t.Fatal("Missing key Username in auth context's payload") + } + + if val != "foo" { + t.Fatal("Invalid Username in auth context's payload") + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) { + t.Fatalf("bad: %v", out) + } +} + +func TestPasswordAuth_Invalid(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{2, NoAuth, UserPassAuth}) + req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + cator := UserPassAuthenticator{Credentials: cred} + s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) + + ctx, err := s.authenticate(&resp, req) + if err != UserAuthFailed { + t.Fatalf("err: %v", err) + } + + if ctx != nil { + t.Fatal("Invalid Context Method") + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) { + t.Fatalf("bad: %v", out) + } +} + +func TestNoSupportedAuth(t *testing.T) { + req := bytes.NewBuffer(nil) + req.Write([]byte{1, NoAuth}) + var resp bytes.Buffer + + cred := StaticCredentials{ + "foo": "bar", + } + cator := UserPassAuthenticator{Credentials: cred} + + s, _ := New(&Config{AuthMethods: []Authenticator{cator}}) + + ctx, err := s.authenticate(&resp, req) + if err != NoSupportedAuth { + t.Fatalf("err: %v", err) + } + + if ctx != nil { + t.Fatal("Invalid Context Method") + } + + out := resp.Bytes() + if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) { + t.Fatalf("bad: %v", out) + } +} diff --git a/go-socks-lb/config.yml b/go-socks-lb/config.yml new file mode 100644 index 0000000..d6ead0e --- /dev/null +++ b/go-socks-lb/config.yml @@ -0,0 +1,9 @@ +proxy: + - url: "socks5://192.168.122.128:1080" + weight: 5 + - url: "socks5://192.168.122.128:1081" + weight: 5 + - url: "socks5://192.168.122.128:1082" + weight: 5 +sticky: true +cache-clean-interval: 10 diff --git a/go-socks-lb/credentials.go b/go-socks-lb/credentials.go new file mode 100644 index 0000000..56da371 --- /dev/null +++ b/go-socks-lb/credentials.go @@ -0,0 +1,17 @@ +package main + +// CredentialStore is used to support user/pass authentication +type CredentialStore interface { + Valid(user, password string) bool +} + +// StaticCredentials enables using a map directly as a credential store +type StaticCredentials map[string]string + +func (s StaticCredentials) Valid(user, password string) bool { + pass, ok := s[user] + if !ok { + return false + } + return password == pass +} diff --git a/go-socks-lb/credentials_test.go b/go-socks-lb/credentials_test.go new file mode 100644 index 0000000..1c8b9fb --- /dev/null +++ b/go-socks-lb/credentials_test.go @@ -0,0 +1,24 @@ +package main + +import ( + "testing" +) + +func TestStaticCredentials(t *testing.T) { + creds := StaticCredentials{ + "foo": "bar", + "baz": "", + } + + if !creds.Valid("foo", "bar") { + t.Fatalf("expect valid") + } + + if !creds.Valid("baz", "") { + t.Fatalf("expect valid") + } + + if creds.Valid("foo", "") { + t.Fatalf("expect invalid") + } +} diff --git a/go-socks-lb/go-socks-lb b/go-socks-lb/go-socks-lb new file mode 100755 index 0000000..af7c7e9 Binary files /dev/null and b/go-socks-lb/go-socks-lb differ diff --git a/go-socks-lb/go.mod b/go-socks-lb/go.mod new file mode 100644 index 0000000..0eb104f --- /dev/null +++ b/go-socks-lb/go.mod @@ -0,0 +1,9 @@ +module git.hmthsn.com/mantao/mop/go-socks-lb + +go 1.14 + +require ( + github.com/mroth/weightedrand v0.2.1 + golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e + gopkg.in/yaml.v2 v2.2.8 +) diff --git a/go-socks-lb/go.sum b/go-socks-lb/go.sum new file mode 100644 index 0000000..0209e91 --- /dev/null +++ b/go-socks-lb/go.sum @@ -0,0 +1,11 @@ +github.com/mroth/weightedrand v0.2.1 h1:ivJastXlhBrj0q931DJ8IwhOLGwrYtPeENWd3WlVI0s= +github.com/mroth/weightedrand v0.2.1/go.mod h1:3p2SIcC8al1YMzGhAIoXD+r9olo/g/cdJgAD905gyNE= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k= +golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/go-socks-lb/main.go b/go-socks-lb/main.go new file mode 100644 index 0000000..31a1a53 --- /dev/null +++ b/go-socks-lb/main.go @@ -0,0 +1,207 @@ +package main + +import ( + "flag" + "fmt" + "math/rand" + "net/http" + "net/url" + "os" + "sync" + "time" + + wr "github.com/mroth/weightedrand" + "golang.org/x/net/proxy" + "gopkg.in/yaml.v2" +) + +const ( + URL = "http://connectivitycheck.gstatic.com/generate_204" +) + +func proxyTestStatusCode(proxyURL string, URL string, StatusCode int) bool { + // create a socks5 dialer + u, err := url.Parse(proxyURL) + if err != nil { + fmt.Println("error parsing") + return false + } + dialer, err := proxy.FromURL(u, proxy.Direct) + if err != nil { + fmt.Fprintln(os.Stderr, "can't connect to the proxy:", err) + return false + } + // setup a http client + httpTransport := &http.Transport{} + httpClient := &http.Client{ + Transport: httpTransport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + fmt.Printf("redirect %v", *req) + return http.ErrUseLastResponse + }, + } + // set our socks5 as the dialer + httpTransport.Dial = dialer.Dial + // create a request + req, err := http.NewRequest("GET", URL, nil) + if err != nil { + fmt.Fprintln(os.Stderr, "can't create request:", err) + return false + } + // use the http client to fetch the page + resp, err := httpClient.Do(req) + if err != nil { + fmt.Fprintln(os.Stderr, "can't GET page:", err) + return false + } + // defer resp.Body.Close() + // b, err := ioutil.ReadAll(resp.Body) + // if err != nil { + // fmt.Fprintln(os.Stderr, "error reading body:", err) + // return false + // } + return resp.StatusCode == StatusCode +} + +func testTask(testfn func() bool, interval int, stop chan bool) { + for { + go testfn() + select { + case <-stop: + fmt.Println("stopping") + return + default: + } + time.Sleep(time.Duration(interval) * time.Second) + } +} + +type proxyInst struct { + URL string + Weight int + StatusHistory []bool +} +type ProxyManager struct { + Proxys []proxyInst + Cache map[string]int + Sticky bool + CacheCleanInterval int + Chooser wr.Chooser + mux sync.Mutex +} +type proxyConf struct { + Proxys []struct { + URL string `yaml:"url"` + Weight int `yaml:"weight"` + } `yaml:"proxy"` + Sticky bool `yaml:"sticky"` + CacheCleanInterval int `yaml:"cache-clean-interval"` +} + +func NewPM(confFp string) (*ProxyManager, error) { + pm := ProxyManager{} + f, err := os.Open(confFp) + if err != nil { + return &ProxyManager{}, err + } + defer f.Close() + + var cfg proxyConf + decoder := yaml.NewDecoder(f) + err = decoder.Decode(&cfg) + if err != nil { + return &ProxyManager{}, err + } + var chooseArr []wr.Choice + for idx, pc := range cfg.Proxys { + var pi proxyInst + pi.URL = pc.URL + pi.Weight = pc.Weight + pm.Proxys = append(pm.Proxys, pi) + chooseArr = append(chooseArr, wr.Choice{Item: idx, Weight: uint(pi.Weight)}) + } + pm.Sticky = cfg.Sticky + pm.CacheCleanInterval = cfg.CacheCleanInterval + pm.Cache = make(map[string]int) + pm.Chooser = wr.NewChooser(chooseArr...) + go pm.keepClearingCache() + return &pm, nil +} + +func (pm *ProxyManager) Get(addr string) string { + pm.mux.Lock() + defer pm.mux.Unlock() + // fmt.Println(pm.Cache) + if pm.Sticky { + idx, ok := pm.Cache[addr] + if ok { + // fmt.Println("match addr", addr, "using:", idx) + return pm.Proxys[idx].URL + } + } + idx := pm.Chooser.Pick().(int) + pm.Cache[addr] = idx + // fmt.Println("addr", addr, "using:", idx) + return pm.Proxys[idx].URL +} + +func (pm *ProxyManager) ClearCache() { + // fmt.Println("clearing cache") + pm.mux.Lock() + pm.Cache = make(map[string]int) + pm.mux.Unlock() + // fmt.Println("after:", pm.Cache) +} + +func (pm *ProxyManager) keepClearingCache() { + interval := pm.CacheCleanInterval + if interval <= 0 { + return + } + for { + time.Sleep(time.Duration(interval) * time.Second) + pm.ClearCache() + } +} + +func main() { + // testFn := func() bool { + // status := proxyTestStatusCode("socks5://127.0.0.1:1080", URL, 204) + // fmt.Println(status) + // return status + // } + // stop := make(chan bool) + // go testTask(testFn, 5, stop) + // time.Sleep(10 * time.Second) + // stop <- true + // time.Sleep(10 * time.Second) + + configPath := flag.String("config", "", "Config file.") + bindAddr := flag.String("bind", "127.0.0.1:7000", "Bind address and port") + flag.Parse() + if *configPath == "" { + flag.PrintDefaults() + os.Exit(1) + } + rand.Seed(time.Now().UTC().UnixNano()) // always seed random! + + var pm *ProxyManager + pm, err := NewPM(*configPath) + if err != nil { + return + } + if pm.Sticky { + fmt.Printf("Sticky with interval %ds.\n", pm.CacheCleanInterval) + } else { + fmt.Println("Randomize every connection.") + } + conf := &Config{PM: pm} + server, err := New(conf) + if err != nil { + panic(err) + } + // Create SOCKS5 proxy on localhost port 8000 + if err := server.ListenAndServe("tcp", *bindAddr); err != nil { + panic(err) + } +} diff --git a/go-socks-lb/request.go b/go-socks-lb/request.go new file mode 100644 index 0000000..1913330 --- /dev/null +++ b/go-socks-lb/request.go @@ -0,0 +1,367 @@ +package main + +import ( + "fmt" + "io" + "net" + "net/url" + "strconv" + "strings" + + "golang.org/x/net/context" + "golang.org/x/net/proxy" +) + +const ( + ConnectCommand = uint8(1) + BindCommand = uint8(2) + AssociateCommand = uint8(3) + ipv4Address = uint8(1) + fqdnAddress = uint8(3) + ipv6Address = uint8(4) +) + +const ( + successReply uint8 = iota + serverFailure + ruleFailure + networkUnreachable + hostUnreachable + connectionRefused + ttlExpired + commandNotSupported + addrTypeNotSupported +) + +var ( + unrecognizedAddrType = fmt.Errorf("Unrecognized address type") +) + +// AddressRewriter is used to rewrite a destination transparently +type AddressRewriter interface { + Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec) +} + +// AddrSpec is used to return the target AddrSpec +// which may be specified as IPv4, IPv6, or a FQDN +type AddrSpec struct { + FQDN string + IP net.IP + Port int +} + +func (a *AddrSpec) String() string { + if a.FQDN != "" { + return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) + } + return fmt.Sprintf("%s:%d", a.IP, a.Port) +} + +// Address returns a string suitable to dial; prefer returning IP-based +// address, fallback to FQDN +func (a AddrSpec) Address() string { + if 0 != len(a.IP) { + return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) + } + return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) +} + +// A Request represents request received by a server +type Request struct { + // Protocol version + Version uint8 + // Requested command + Command uint8 + // AuthContext provided during negotiation + AuthContext *AuthContext + // AddrSpec of the the network that sent the request + RemoteAddr *AddrSpec + // AddrSpec of the desired destination + DestAddr *AddrSpec + // AddrSpec of the actual destination (might be affected by rewrite) + realDestAddr *AddrSpec + bufConn io.Reader +} + +type conn interface { + Write([]byte) (int, error) + RemoteAddr() net.Addr +} + +// NewRequest creates a new Request from the tcp connection +func NewRequest(bufConn io.Reader) (*Request, error) { + // Read the version byte + header := []byte{0, 0, 0} + if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { + return nil, fmt.Errorf("Failed to get command version: %v", err) + } + + // Ensure we are compatible + if header[0] != socks5Version { + return nil, fmt.Errorf("Unsupported command version: %v", header[0]) + } + + // Read in the destination address + dest, err := readAddrSpec(bufConn) + if err != nil { + return nil, err + } + + request := &Request{ + Version: socks5Version, + Command: header[1], + DestAddr: dest, + bufConn: bufConn, + } + + return request, nil +} + +// handleRequest is used for request processing after authentication +func (s *Server) handleRequest(req *Request, conn conn) error { + ctx := context.Background() + + // Resolve the address if we have a FQDN + // dest := req.DestAddr + // if dest.FQDN != "" { + // ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN) + // if err != nil { + // if err := sendReply(conn, hostUnreachable, nil); err != nil { + // return fmt.Errorf("Failed to send reply: %v", err) + // } + // return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err) + // } + // ctx = ctx_ + // dest.IP = addr + // } + + // Apply any address rewrites + req.realDestAddr = req.DestAddr + if s.config.Rewriter != nil { + ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) + } + + // Switch on the command + switch req.Command { + case ConnectCommand: + return s.handleConnect(ctx, conn, req) + case BindCommand: + return s.handleBind(ctx, conn, req) + case AssociateCommand: + return s.handleAssociate(ctx, conn, req) + default: + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Unsupported command: %v", req.Command) + } +} + +// handleConnect is used to handle a connect command +func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // Attempt to connect + // dial := s.config.Dial + dial := func(ctx context.Context, net_, addr string) (net.Conn, error) { + var dp proxy.Dialer + u, _ := url.Parse(s.config.PM.Get(addr)) + dp, _ = proxy.FromURL(u, proxy.Direct) + return dp.Dial(net_, addr) + } + target, err := dial(ctx, "tcp", req.realDestAddr.Address()) + if err != nil { + msg := err.Error() + resp := hostUnreachable + if strings.Contains(msg, "refused") { + resp = connectionRefused + } else if strings.Contains(msg, "network is unreachable") { + resp = networkUnreachable + } + if err := sendReply(conn, resp, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err) + } + defer target.Close() + + // Send success + local := target.LocalAddr().(*net.TCPAddr) + bind := AddrSpec{IP: local.IP, Port: local.Port} + if err := sendReply(conn, successReply, &bind); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + + // Start proxying + errCh := make(chan error, 2) + go proxy_conn(target, req.bufConn, errCh) + go proxy_conn(conn, target, errCh) + + // Wait + for i := 0; i < 2; i++ { + e := <-errCh + if e != nil { + // return from this function closes target (and conn). + return e + } + } + return nil +} + +// handleBind is used to handle a connect command +func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // TODO: Support bind + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return nil +} + +// handleAssociate is used to handle a connect command +func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error { + // Check if this is allowed + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { + if err := sendReply(conn, ruleFailure, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr) + } else { + ctx = ctx_ + } + + // TODO: Support associate + if err := sendReply(conn, commandNotSupported, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + return nil +} + +// readAddrSpec is used to read AddrSpec. +// Expects an address type byte, follwed by the address and port +func readAddrSpec(r io.Reader) (*AddrSpec, error) { + d := &AddrSpec{} + + // Get the address type + addrType := []byte{0} + if _, err := r.Read(addrType); err != nil { + return nil, err + } + + // Handle on a per type basis + switch addrType[0] { + case ipv4Address: + addr := make([]byte, 4) + if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { + return nil, err + } + d.IP = net.IP(addr) + + case ipv6Address: + addr := make([]byte, 16) + if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { + return nil, err + } + d.IP = net.IP(addr) + + case fqdnAddress: + if _, err := r.Read(addrType); err != nil { + return nil, err + } + addrLen := int(addrType[0]) + fqdn := make([]byte, addrLen) + if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil { + return nil, err + } + d.FQDN = string(fqdn) + + default: + return nil, unrecognizedAddrType + } + + // Read the port + port := []byte{0, 0} + if _, err := io.ReadAtLeast(r, port, 2); err != nil { + return nil, err + } + d.Port = (int(port[0]) << 8) | int(port[1]) + + return d, nil +} + +// sendReply is used to send a reply message +func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { + // Format the address + var addrType uint8 + var addrBody []byte + var addrPort uint16 + switch { + case addr == nil: + addrType = ipv4Address + addrBody = []byte{0, 0, 0, 0} + addrPort = 0 + + case addr.FQDN != "": + addrType = fqdnAddress + addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) + addrPort = uint16(addr.Port) + + case addr.IP.To4() != nil: + addrType = ipv4Address + addrBody = []byte(addr.IP.To4()) + addrPort = uint16(addr.Port) + + case addr.IP.To16() != nil: + addrType = ipv6Address + addrBody = []byte(addr.IP.To16()) + addrPort = uint16(addr.Port) + + default: + return fmt.Errorf("Failed to format address: %v", addr) + } + + // Format the message + msg := make([]byte, 6+len(addrBody)) + msg[0] = socks5Version + msg[1] = resp + msg[2] = 0 // Reserved + msg[3] = addrType + copy(msg[4:], addrBody) + msg[4+len(addrBody)] = byte(addrPort >> 8) + msg[4+len(addrBody)+1] = byte(addrPort & 0xff) + + // Send the message + _, err := w.Write(msg) + return err +} + +type closeWriter interface { + CloseWrite() error +} + +// proxy is used to suffle data from src to destination, and sends errors +// down a dedicated channel +func proxy_conn(dst io.Writer, src io.Reader, errCh chan error) { + _, err := io.Copy(dst, src) + if tcpConn, ok := dst.(closeWriter); ok { + tcpConn.CloseWrite() + } + errCh <- err +} diff --git a/go-socks-lb/request_test.go b/go-socks-lb/request_test.go new file mode 100644 index 0000000..80ef6fb --- /dev/null +++ b/go-socks-lb/request_test.go @@ -0,0 +1,169 @@ +package main + +import ( + "bytes" + "encoding/binary" + "io" + "log" + "net" + "os" + "strings" + "testing" +) + +type MockConn struct { + buf bytes.Buffer +} + +func (m *MockConn) Write(b []byte) (int, error) { + return m.buf.Write(b) +} + +func (m *MockConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432} +} + +func TestRequest_Connect(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + go func() { + conn, err := l.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + buf := make([]byte, 4) + if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { + t.Fatalf("err: %v", err) + } + + if !bytes.Equal(buf, []byte("ping")) { + t.Fatalf("bad: %v", buf) + } + conn.Write([]byte("pong")) + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Make server + s := &Server{config: &Config{ + Rules: PermitAll(), + Resolver: DNSResolver{}, + Logger: log.New(os.Stdout, "", log.LstdFlags), + }} + + // Create the connect request + buf := bytes.NewBuffer(nil) + buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) + buf.Write(port) + + // Send a ping + buf.Write([]byte("ping")) + + // Handle the request + resp := &MockConn{} + req, err := NewRequest(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.handleRequest(req, resp); err != nil { + t.Fatalf("err: %v", err) + } + + // Verify response + out := resp.buf.Bytes() + expected := []byte{ + 5, + 0, + 0, + 1, + 127, 0, 0, 1, + 0, 0, + 'p', 'o', 'n', 'g', + } + + // Ignore the port for both + out[8] = 0 + out[9] = 0 + + if !bytes.Equal(out, expected) { + t.Fatalf("bad: %v %v", out, expected) + } +} + +func TestRequest_Connect_RuleFail(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + go func() { + conn, err := l.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + buf := make([]byte, 4) + if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { + t.Fatalf("err: %v", err) + } + + if !bytes.Equal(buf, []byte("ping")) { + t.Fatalf("bad: %v", buf) + } + conn.Write([]byte("pong")) + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Make server + s := &Server{config: &Config{ + Rules: PermitNone(), + Resolver: DNSResolver{}, + Logger: log.New(os.Stdout, "", log.LstdFlags), + }} + + // Create the connect request + buf := bytes.NewBuffer(nil) + buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) + buf.Write(port) + + // Send a ping + buf.Write([]byte("ping")) + + // Handle the request + resp := &MockConn{} + req, err := NewRequest(buf) + if err != nil { + t.Fatalf("err: %v", err) + } + + if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") { + t.Fatalf("err: %v", err) + } + + // Verify response + out := resp.buf.Bytes() + expected := []byte{ + 5, + 2, + 0, + 1, + 0, 0, 0, 0, + 0, 0, + } + + if !bytes.Equal(out, expected) { + t.Fatalf("bad: %v %v", out, expected) + } +} diff --git a/go-socks-lb/resolver.go b/go-socks-lb/resolver.go new file mode 100644 index 0000000..b531dd1 --- /dev/null +++ b/go-socks-lb/resolver.go @@ -0,0 +1,23 @@ +package main + +import ( + "net" + + "golang.org/x/net/context" +) + +// NameResolver is used to implement custom name resolution +type NameResolver interface { + Resolve(ctx context.Context, name string) (context.Context, net.IP, error) +} + +// DNSResolver uses the system DNS to resolve host names +type DNSResolver struct{} + +func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) { + addr, err := net.ResolveIPAddr("ip", name) + if err != nil { + return ctx, nil, err + } + return ctx, addr.IP, err +} diff --git a/go-socks-lb/resolver_test.go b/go-socks-lb/resolver_test.go new file mode 100644 index 0000000..b017f4c --- /dev/null +++ b/go-socks-lb/resolver_test.go @@ -0,0 +1,21 @@ +package main + +import ( + "testing" + + "golang.org/x/net/context" +) + +func TestDNSResolver(t *testing.T) { + d := DNSResolver{} + ctx := context.Background() + + _, addr, err := d.Resolve(ctx, "localhost") + if err != nil { + t.Fatalf("err: %v", err) + } + + if !addr.IsLoopback() { + t.Fatalf("expected loopback") + } +} diff --git a/go-socks-lb/ruleset.go b/go-socks-lb/ruleset.go new file mode 100644 index 0000000..f84bba2 --- /dev/null +++ b/go-socks-lb/ruleset.go @@ -0,0 +1,41 @@ +package main + +import ( + "golang.org/x/net/context" +) + +// RuleSet is used to provide custom rules to allow or prohibit actions +type RuleSet interface { + Allow(ctx context.Context, req *Request) (context.Context, bool) +} + +// PermitAll returns a RuleSet which allows all types of connections +func PermitAll() RuleSet { + return &PermitCommand{true, true, true} +} + +// PermitNone returns a RuleSet which disallows all types of connections +func PermitNone() RuleSet { + return &PermitCommand{false, false, false} +} + +// PermitCommand is an implementation of the RuleSet which +// enables filtering supported commands +type PermitCommand struct { + EnableConnect bool + EnableBind bool + EnableAssociate bool +} + +func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) { + switch req.Command { + case ConnectCommand: + return ctx, p.EnableConnect + case BindCommand: + return ctx, p.EnableBind + case AssociateCommand: + return ctx, p.EnableAssociate + } + + return ctx, false +} diff --git a/go-socks-lb/ruleset_test.go b/go-socks-lb/ruleset_test.go new file mode 100644 index 0000000..d770947 --- /dev/null +++ b/go-socks-lb/ruleset_test.go @@ -0,0 +1,24 @@ +package main + +import ( + "testing" + + "golang.org/x/net/context" +) + +func TestPermitCommand(t *testing.T) { + ctx := context.Background() + r := &PermitCommand{true, false, false} + + if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok { + t.Fatalf("expect connect") + } + + if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok { + t.Fatalf("do not expect bind") + } + + if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok { + t.Fatalf("do not expect associate") + } +} diff --git a/go-socks-lb/socks5.go b/go-socks-lb/socks5.go new file mode 100644 index 0000000..0294203 --- /dev/null +++ b/go-socks-lb/socks5.go @@ -0,0 +1,172 @@ +package main + +import ( + "bufio" + "fmt" + "log" + "net" + "os" + + "golang.org/x/net/context" +) + +const ( + socks5Version = uint8(5) +) + +// Config is used to setup and configure a Server +type Config struct { + // AuthMethods can be provided to implement custom authentication + // By default, "auth-less" mode is enabled. + // For password-based auth use UserPassAuthenticator. + AuthMethods []Authenticator + + // If provided, username/password authentication is enabled, + // by appending a UserPassAuthenticator to AuthMethods. If not provided, + // and AUthMethods is nil, then "auth-less" mode is enabled. + Credentials CredentialStore + + // Resolver can be provided to do custom name resolution. + // Defaults to DNSResolver if not provided. + Resolver NameResolver + + // Rules is provided to enable custom logic around permitting + // various commands. If not provided, PermitAll is used. + Rules RuleSet + + // Rewriter can be used to transparently rewrite addresses. + // This is invoked before the RuleSet is invoked. + // Defaults to NoRewrite. + Rewriter AddressRewriter + + // BindIP is used for bind or udp associate + BindIP net.IP + + // Logger can be used to provide a custom log target. + // Defaults to stdout. + Logger *log.Logger + + // Managing proxy selections + PM *ProxyManager + + // Optional function for dialing out + Dial func(ctx context.Context, network, addr string) (net.Conn, error) +} + +// Server is reponsible for accepting connections and handling +// the details of the SOCKS5 protocol +type Server struct { + config *Config + authMethods map[uint8]Authenticator +} + +// New creates a new Server and potentially returns an error +func New(conf *Config) (*Server, error) { + // Ensure we have at least one authentication method enabled + if len(conf.AuthMethods) == 0 { + if conf.Credentials != nil { + conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}} + } else { + conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}} + } + } + + // Ensure we have a DNS resolver + if conf.Resolver == nil { + conf.Resolver = DNSResolver{} + } + + // Ensure we have a rule set + if conf.Rules == nil { + conf.Rules = PermitAll() + } + + // Ensure we have a log target + if conf.Logger == nil { + conf.Logger = log.New(os.Stdout, "", log.LstdFlags) + } + + server := &Server{ + config: conf, + } + + server.authMethods = make(map[uint8]Authenticator) + + for _, a := range conf.AuthMethods { + server.authMethods[a.GetCode()] = a + } + + return server, nil +} + +// ListenAndServe is used to create a listener and serve on it +func (s *Server) ListenAndServe(network, addr string) error { + l, err := net.Listen(network, addr) + if err != nil { + return err + } + return s.Serve(l) +} + +// Serve is used to serve connections from a listener +func (s *Server) Serve(l net.Listener) error { + for { + conn, err := l.Accept() + if err != nil { + return err + } + go s.ServeConn(conn) + } + return nil +} + +// ServeConn is used to serve a single connection. +func (s *Server) ServeConn(conn net.Conn) error { + defer conn.Close() + bufConn := bufio.NewReader(conn) + + // Read the version byte + version := []byte{0} + if _, err := bufConn.Read(version); err != nil { + s.config.Logger.Printf("[ERR] socks: Failed to get version byte: %v", err) + return err + } + + // Ensure we are compatible + if version[0] != socks5Version { + err := fmt.Errorf("Unsupported SOCKS version: %v", version) + s.config.Logger.Printf("[ERR] socks: %v", err) + return err + } + + // Authenticate the connection + authContext, err := s.authenticate(conn, bufConn) + if err != nil { + err = fmt.Errorf("Failed to authenticate: %v", err) + s.config.Logger.Printf("[ERR] socks: %v", err) + return err + } + + request, err := NewRequest(bufConn) + if err != nil { + if err == unrecognizedAddrType { + if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { + return fmt.Errorf("Failed to send reply: %v", err) + } + } + return fmt.Errorf("Failed to read destination address: %v", err) + } + request.AuthContext = authContext + if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} + } + + // Process the client request + if err := s.handleRequest(request, conn); err != nil { + err = fmt.Errorf("Failed to handle request: %v", err) + s.config.Logger.Printf("[ERR] socks: %v", err) + return err + } + + return nil +} diff --git a/go-socks-lb/socks5_test.go b/go-socks-lb/socks5_test.go new file mode 100644 index 0000000..1c9b064 --- /dev/null +++ b/go-socks-lb/socks5_test.go @@ -0,0 +1,110 @@ +package main + +import ( + "bytes" + "encoding/binary" + "io" + "log" + "net" + "os" + "testing" + "time" +) + +func TestSOCKS5_Connect(t *testing.T) { + // Create a local listener + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %v", err) + } + go func() { + conn, err := l.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + defer conn.Close() + + buf := make([]byte, 4) + if _, err := io.ReadAtLeast(conn, buf, 4); err != nil { + t.Fatalf("err: %v", err) + } + + if !bytes.Equal(buf, []byte("ping")) { + t.Fatalf("bad: %v", buf) + } + conn.Write([]byte("pong")) + }() + lAddr := l.Addr().(*net.TCPAddr) + + // Create a socks server + creds := StaticCredentials{ + "foo": "bar", + } + cator := UserPassAuthenticator{Credentials: creds} + conf := &Config{ + AuthMethods: []Authenticator{cator}, + Logger: log.New(os.Stdout, "", log.LstdFlags), + } + serv, err := New(conf) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Start listening + go func() { + if err := serv.ListenAndServe("tcp", "127.0.0.1:12365"); err != nil { + t.Fatalf("err: %v", err) + } + }() + time.Sleep(10 * time.Millisecond) + + // Get a local conn + conn, err := net.Dial("tcp", "127.0.0.1:12365") + if err != nil { + t.Fatalf("err: %v", err) + } + + // Connect, auth and connec to local + req := bytes.NewBuffer(nil) + req.Write([]byte{5}) + req.Write([]byte{2, NoAuth, UserPassAuth}) + req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'}) + req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1}) + + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(lAddr.Port)) + req.Write(port) + + // Send a ping + req.Write([]byte("ping")) + + // Send all the bytes + conn.Write(req.Bytes()) + + // Verify response + expected := []byte{ + socks5Version, UserPassAuth, + 1, authSuccess, + 5, + 0, + 0, + 1, + 127, 0, 0, 1, + 0, 0, + 'p', 'o', 'n', 'g', + } + out := make([]byte, len(expected)) + + conn.SetDeadline(time.Now().Add(time.Second)) + if _, err := io.ReadAtLeast(conn, out, len(out)); err != nil { + t.Fatalf("err: %v", err) + } + + // Ignore the port + out[12] = 0 + out[13] = 0 + + if !bytes.Equal(out, expected) { + t.Fatalf("bad: %v", out) + } +}