including go load balancer

This commit is contained in:
mantaohuang 2020-04-07 22:09:39 -04:00
parent ba96edbd7b
commit 99b8f3dfba
20 changed files with 1561 additions and 0 deletions

22
go-socks-lb/.gitignore vendored Normal file
View File

@ -0,0 +1,22 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe

20
go-socks-lb/LICENSE Normal file
View File

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2014 Armon Dadgar
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

45
go-socks-lb/README.md Normal file
View File

@ -0,0 +1,45 @@
go-socks5 [![Build Status](https://travis-ci.org/armon/go-socks5.png)](https://travis-ci.org/armon/go-socks5)
=========
Provides the `socks5` package that implements a [SOCKS5 server](http://en.wikipedia.org/wiki/SOCKS).
SOCKS (Secure Sockets) is used to route traffic between a client and server through
an intermediate proxy layer. This can be used to bypass firewalls or NATs.
Feature
=======
The package has the following features:
* "No Auth" mode
* User/Password authentication
* Support for the CONNECT command
* Rules to do granular filtering of commands
* Custom DNS resolution
* Unit tests
TODO
====
The package still needs the following:
* Support for the BIND command
* Support for the ASSOCIATE command
Example
=======
Below is a simple example of usage
```go
// Create a SOCKS5 server
conf := &socks5.Config{}
server, err := socks5.New(conf)
if err != nil {
panic(err)
}
// Create SOCKS5 proxy on localhost port 8000
if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil {
panic(err)
}
```

151
go-socks-lb/auth.go Normal file
View File

@ -0,0 +1,151 @@
package main
import (
"fmt"
"io"
)
const (
NoAuth = uint8(0)
noAcceptable = uint8(255)
UserPassAuth = uint8(2)
userAuthVersion = uint8(1)
authSuccess = uint8(0)
authFailure = uint8(1)
)
var (
UserAuthFailed = fmt.Errorf("User authentication failed")
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
)
// A Request encapsulates authentication state provided
// during negotiation
type AuthContext struct {
// Provided auth method
Method uint8
// Payload provided during negotiation.
// Keys depend on the used auth method.
// For UserPassauth contains Username
Payload map[string]string
}
type Authenticator interface {
Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
GetCode() uint8
}
// NoAuthAuthenticator is used to handle the "No Authentication" mode
type NoAuthAuthenticator struct{}
func (a NoAuthAuthenticator) GetCode() uint8 {
return NoAuth
}
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
_, err := writer.Write([]byte{socks5Version, NoAuth})
return &AuthContext{NoAuth, nil}, err
}
// UserPassAuthenticator is used to handle username/password based
// authentication
type UserPassAuthenticator struct {
Credentials CredentialStore
}
func (a UserPassAuthenticator) GetCode() uint8 {
return UserPassAuth
}
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
// Tell the client to use user/pass auth
if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
return nil, err
}
// Get the version and username length
header := []byte{0, 0}
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
return nil, err
}
// Ensure we are compatible
if header[0] != userAuthVersion {
return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
}
// Get the user name
userLen := int(header[1])
user := make([]byte, userLen)
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
return nil, err
}
// Get the password length
if _, err := reader.Read(header[:1]); err != nil {
return nil, err
}
// Get the password
passLen := int(header[0])
pass := make([]byte, passLen)
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
return nil, err
}
// Verify the password
if a.Credentials.Valid(string(user), string(pass)) {
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
return nil, err
}
} else {
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
return nil, err
}
return nil, UserAuthFailed
}
// Done
return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil
}
// authenticate is used to handle connection authentication
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
// Get the methods
methods, err := readMethods(bufConn)
if err != nil {
return nil, fmt.Errorf("Failed to get auth methods: %v", err)
}
// Select a usable method
for _, method := range methods {
cator, found := s.authMethods[method]
if found {
return cator.Authenticate(bufConn, conn)
}
}
// No usable method found
return nil, noAcceptableAuth(conn)
}
// noAcceptableAuth is used to handle when we have no eligible
// authentication mechanism
func noAcceptableAuth(conn io.Writer) error {
conn.Write([]byte{socks5Version, noAcceptable})
return NoSupportedAuth
}
// readMethods is used to read the number of methods
// and proceeding auth methods
func readMethods(r io.Reader) ([]byte, error) {
header := []byte{0}
if _, err := r.Read(header); err != nil {
return nil, err
}
numMethods := int(header[0])
methods := make([]byte, numMethods)
_, err := io.ReadAtLeast(r, methods, numMethods)
return methods, err
}

119
go-socks-lb/auth_test.go Normal file
View File

@ -0,0 +1,119 @@
package main
import (
"bytes"
"testing"
)
func TestNoAuth(t *testing.T) {
req := bytes.NewBuffer(nil)
req.Write([]byte{1, NoAuth})
var resp bytes.Buffer
s, _ := New(&Config{})
ctx, err := s.authenticate(&resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if ctx.Method != NoAuth {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, NoAuth}) {
t.Fatalf("bad: %v", out)
}
}
func TestPasswordAuth_Valid(t *testing.T) {
req := bytes.NewBuffer(nil)
req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
var resp bytes.Buffer
cred := StaticCredentials{
"foo": "bar",
}
cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
ctx, err := s.authenticate(&resp, req)
if err != nil {
t.Fatalf("err: %v", err)
}
if ctx.Method != UserPassAuth {
t.Fatal("Invalid Context Method")
}
val, ok := ctx.Payload["Username"]
if !ok {
t.Fatal("Missing key Username in auth context's payload")
}
if val != "foo" {
t.Fatal("Invalid Username in auth context's payload")
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) {
t.Fatalf("bad: %v", out)
}
}
func TestPasswordAuth_Invalid(t *testing.T) {
req := bytes.NewBuffer(nil)
req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
var resp bytes.Buffer
cred := StaticCredentials{
"foo": "bar",
}
cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
ctx, err := s.authenticate(&resp, req)
if err != UserAuthFailed {
t.Fatalf("err: %v", err)
}
if ctx != nil {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) {
t.Fatalf("bad: %v", out)
}
}
func TestNoSupportedAuth(t *testing.T) {
req := bytes.NewBuffer(nil)
req.Write([]byte{1, NoAuth})
var resp bytes.Buffer
cred := StaticCredentials{
"foo": "bar",
}
cator := UserPassAuthenticator{Credentials: cred}
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
ctx, err := s.authenticate(&resp, req)
if err != NoSupportedAuth {
t.Fatalf("err: %v", err)
}
if ctx != nil {
t.Fatal("Invalid Context Method")
}
out := resp.Bytes()
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
t.Fatalf("bad: %v", out)
}
}

9
go-socks-lb/config.yml Normal file
View File

@ -0,0 +1,9 @@
proxy:
- url: "socks5://192.168.122.128:1080"
weight: 5
- url: "socks5://192.168.122.128:1081"
weight: 5
- url: "socks5://192.168.122.128:1082"
weight: 5
sticky: true
cache-clean-interval: 10

View File

@ -0,0 +1,17 @@
package main
// CredentialStore is used to support user/pass authentication
type CredentialStore interface {
Valid(user, password string) bool
}
// StaticCredentials enables using a map directly as a credential store
type StaticCredentials map[string]string
func (s StaticCredentials) Valid(user, password string) bool {
pass, ok := s[user]
if !ok {
return false
}
return password == pass
}

View File

@ -0,0 +1,24 @@
package main
import (
"testing"
)
func TestStaticCredentials(t *testing.T) {
creds := StaticCredentials{
"foo": "bar",
"baz": "",
}
if !creds.Valid("foo", "bar") {
t.Fatalf("expect valid")
}
if !creds.Valid("baz", "") {
t.Fatalf("expect valid")
}
if creds.Valid("foo", "") {
t.Fatalf("expect invalid")
}
}

BIN
go-socks-lb/go-socks-lb Executable file

Binary file not shown.

9
go-socks-lb/go.mod Normal file
View File

@ -0,0 +1,9 @@
module git.hmthsn.com/mantao/mop/go-socks-lb
go 1.14
require (
github.com/mroth/weightedrand v0.2.1
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e
gopkg.in/yaml.v2 v2.2.8
)

11
go-socks-lb/go.sum Normal file
View File

@ -0,0 +1,11 @@
github.com/mroth/weightedrand v0.2.1 h1:ivJastXlhBrj0q931DJ8IwhOLGwrYtPeENWd3WlVI0s=
github.com/mroth/weightedrand v0.2.1/go.mod h1:3p2SIcC8al1YMzGhAIoXD+r9olo/g/cdJgAD905gyNE=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k=
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

207
go-socks-lb/main.go Normal file
View File

@ -0,0 +1,207 @@
package main
import (
"flag"
"fmt"
"math/rand"
"net/http"
"net/url"
"os"
"sync"
"time"
wr "github.com/mroth/weightedrand"
"golang.org/x/net/proxy"
"gopkg.in/yaml.v2"
)
const (
URL = "http://connectivitycheck.gstatic.com/generate_204"
)
func proxyTestStatusCode(proxyURL string, URL string, StatusCode int) bool {
// create a socks5 dialer
u, err := url.Parse(proxyURL)
if err != nil {
fmt.Println("error parsing")
return false
}
dialer, err := proxy.FromURL(u, proxy.Direct)
if err != nil {
fmt.Fprintln(os.Stderr, "can't connect to the proxy:", err)
return false
}
// setup a http client
httpTransport := &http.Transport{}
httpClient := &http.Client{
Transport: httpTransport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
fmt.Printf("redirect %v", *req)
return http.ErrUseLastResponse
},
}
// set our socks5 as the dialer
httpTransport.Dial = dialer.Dial
// create a request
req, err := http.NewRequest("GET", URL, nil)
if err != nil {
fmt.Fprintln(os.Stderr, "can't create request:", err)
return false
}
// use the http client to fetch the page
resp, err := httpClient.Do(req)
if err != nil {
fmt.Fprintln(os.Stderr, "can't GET page:", err)
return false
}
// defer resp.Body.Close()
// b, err := ioutil.ReadAll(resp.Body)
// if err != nil {
// fmt.Fprintln(os.Stderr, "error reading body:", err)
// return false
// }
return resp.StatusCode == StatusCode
}
func testTask(testfn func() bool, interval int, stop chan bool) {
for {
go testfn()
select {
case <-stop:
fmt.Println("stopping")
return
default:
}
time.Sleep(time.Duration(interval) * time.Second)
}
}
type proxyInst struct {
URL string
Weight int
StatusHistory []bool
}
type ProxyManager struct {
Proxys []proxyInst
Cache map[string]int
Sticky bool
CacheCleanInterval int
Chooser wr.Chooser
mux sync.Mutex
}
type proxyConf struct {
Proxys []struct {
URL string `yaml:"url"`
Weight int `yaml:"weight"`
} `yaml:"proxy"`
Sticky bool `yaml:"sticky"`
CacheCleanInterval int `yaml:"cache-clean-interval"`
}
func NewPM(confFp string) (*ProxyManager, error) {
pm := ProxyManager{}
f, err := os.Open(confFp)
if err != nil {
return &ProxyManager{}, err
}
defer f.Close()
var cfg proxyConf
decoder := yaml.NewDecoder(f)
err = decoder.Decode(&cfg)
if err != nil {
return &ProxyManager{}, err
}
var chooseArr []wr.Choice
for idx, pc := range cfg.Proxys {
var pi proxyInst
pi.URL = pc.URL
pi.Weight = pc.Weight
pm.Proxys = append(pm.Proxys, pi)
chooseArr = append(chooseArr, wr.Choice{Item: idx, Weight: uint(pi.Weight)})
}
pm.Sticky = cfg.Sticky
pm.CacheCleanInterval = cfg.CacheCleanInterval
pm.Cache = make(map[string]int)
pm.Chooser = wr.NewChooser(chooseArr...)
go pm.keepClearingCache()
return &pm, nil
}
func (pm *ProxyManager) Get(addr string) string {
pm.mux.Lock()
defer pm.mux.Unlock()
// fmt.Println(pm.Cache)
if pm.Sticky {
idx, ok := pm.Cache[addr]
if ok {
// fmt.Println("match addr", addr, "using:", idx)
return pm.Proxys[idx].URL
}
}
idx := pm.Chooser.Pick().(int)
pm.Cache[addr] = idx
// fmt.Println("addr", addr, "using:", idx)
return pm.Proxys[idx].URL
}
func (pm *ProxyManager) ClearCache() {
// fmt.Println("clearing cache")
pm.mux.Lock()
pm.Cache = make(map[string]int)
pm.mux.Unlock()
// fmt.Println("after:", pm.Cache)
}
func (pm *ProxyManager) keepClearingCache() {
interval := pm.CacheCleanInterval
if interval <= 0 {
return
}
for {
time.Sleep(time.Duration(interval) * time.Second)
pm.ClearCache()
}
}
func main() {
// testFn := func() bool {
// status := proxyTestStatusCode("socks5://127.0.0.1:1080", URL, 204)
// fmt.Println(status)
// return status
// }
// stop := make(chan bool)
// go testTask(testFn, 5, stop)
// time.Sleep(10 * time.Second)
// stop <- true
// time.Sleep(10 * time.Second)
configPath := flag.String("config", "", "Config file.")
bindAddr := flag.String("bind", "127.0.0.1:7000", "Bind address and port")
flag.Parse()
if *configPath == "" {
flag.PrintDefaults()
os.Exit(1)
}
rand.Seed(time.Now().UTC().UnixNano()) // always seed random!
var pm *ProxyManager
pm, err := NewPM(*configPath)
if err != nil {
return
}
if pm.Sticky {
fmt.Printf("Sticky with interval %ds.\n", pm.CacheCleanInterval)
} else {
fmt.Println("Randomize every connection.")
}
conf := &Config{PM: pm}
server, err := New(conf)
if err != nil {
panic(err)
}
// Create SOCKS5 proxy on localhost port 8000
if err := server.ListenAndServe("tcp", *bindAddr); err != nil {
panic(err)
}
}

367
go-socks-lb/request.go Normal file
View File

@ -0,0 +1,367 @@
package main
import (
"fmt"
"io"
"net"
"net/url"
"strconv"
"strings"
"golang.org/x/net/context"
"golang.org/x/net/proxy"
)
const (
ConnectCommand = uint8(1)
BindCommand = uint8(2)
AssociateCommand = uint8(3)
ipv4Address = uint8(1)
fqdnAddress = uint8(3)
ipv6Address = uint8(4)
)
const (
successReply uint8 = iota
serverFailure
ruleFailure
networkUnreachable
hostUnreachable
connectionRefused
ttlExpired
commandNotSupported
addrTypeNotSupported
)
var (
unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
)
// AddressRewriter is used to rewrite a destination transparently
type AddressRewriter interface {
Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
}
// AddrSpec is used to return the target AddrSpec
// which may be specified as IPv4, IPv6, or a FQDN
type AddrSpec struct {
FQDN string
IP net.IP
Port int
}
func (a *AddrSpec) String() string {
if a.FQDN != "" {
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
}
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
// Address returns a string suitable to dial; prefer returning IP-based
// address, fallback to FQDN
func (a AddrSpec) Address() string {
if 0 != len(a.IP) {
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
}
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
}
// A Request represents request received by a server
type Request struct {
// Protocol version
Version uint8
// Requested command
Command uint8
// AuthContext provided during negotiation
AuthContext *AuthContext
// AddrSpec of the the network that sent the request
RemoteAddr *AddrSpec
// AddrSpec of the desired destination
DestAddr *AddrSpec
// AddrSpec of the actual destination (might be affected by rewrite)
realDestAddr *AddrSpec
bufConn io.Reader
}
type conn interface {
Write([]byte) (int, error)
RemoteAddr() net.Addr
}
// NewRequest creates a new Request from the tcp connection
func NewRequest(bufConn io.Reader) (*Request, error) {
// Read the version byte
header := []byte{0, 0, 0}
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
return nil, fmt.Errorf("Failed to get command version: %v", err)
}
// Ensure we are compatible
if header[0] != socks5Version {
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
}
// Read in the destination address
dest, err := readAddrSpec(bufConn)
if err != nil {
return nil, err
}
request := &Request{
Version: socks5Version,
Command: header[1],
DestAddr: dest,
bufConn: bufConn,
}
return request, nil
}
// handleRequest is used for request processing after authentication
func (s *Server) handleRequest(req *Request, conn conn) error {
ctx := context.Background()
// Resolve the address if we have a FQDN
// dest := req.DestAddr
// if dest.FQDN != "" {
// ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
// if err != nil {
// if err := sendReply(conn, hostUnreachable, nil); err != nil {
// return fmt.Errorf("Failed to send reply: %v", err)
// }
// return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
// }
// ctx = ctx_
// dest.IP = addr
// }
// Apply any address rewrites
req.realDestAddr = req.DestAddr
if s.config.Rewriter != nil {
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
}
// Switch on the command
switch req.Command {
case ConnectCommand:
return s.handleConnect(ctx, conn, req)
case BindCommand:
return s.handleBind(ctx, conn, req)
case AssociateCommand:
return s.handleAssociate(ctx, conn, req)
default:
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Unsupported command: %v", req.Command)
}
}
// handleConnect is used to handle a connect command
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// Attempt to connect
// dial := s.config.Dial
dial := func(ctx context.Context, net_, addr string) (net.Conn, error) {
var dp proxy.Dialer
u, _ := url.Parse(s.config.PM.Get(addr))
dp, _ = proxy.FromURL(u, proxy.Direct)
return dp.Dial(net_, addr)
}
target, err := dial(ctx, "tcp", req.realDestAddr.Address())
if err != nil {
msg := err.Error()
resp := hostUnreachable
if strings.Contains(msg, "refused") {
resp = connectionRefused
} else if strings.Contains(msg, "network is unreachable") {
resp = networkUnreachable
}
if err := sendReply(conn, resp, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
}
defer target.Close()
// Send success
local := target.LocalAddr().(*net.TCPAddr)
bind := AddrSpec{IP: local.IP, Port: local.Port}
if err := sendReply(conn, successReply, &bind); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
// Start proxying
errCh := make(chan error, 2)
go proxy_conn(target, req.bufConn, errCh)
go proxy_conn(conn, target, errCh)
// Wait
for i := 0; i < 2; i++ {
e := <-errCh
if e != nil {
// return from this function closes target (and conn).
return e
}
}
return nil
}
// handleBind is used to handle a connect command
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// TODO: Support bind
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return nil
}
// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// TODO: Support associate
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return nil
}
// readAddrSpec is used to read AddrSpec.
// Expects an address type byte, follwed by the address and port
func readAddrSpec(r io.Reader) (*AddrSpec, error) {
d := &AddrSpec{}
// Get the address type
addrType := []byte{0}
if _, err := r.Read(addrType); err != nil {
return nil, err
}
// Handle on a per type basis
switch addrType[0] {
case ipv4Address:
addr := make([]byte, 4)
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err
}
d.IP = net.IP(addr)
case ipv6Address:
addr := make([]byte, 16)
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err
}
d.IP = net.IP(addr)
case fqdnAddress:
if _, err := r.Read(addrType); err != nil {
return nil, err
}
addrLen := int(addrType[0])
fqdn := make([]byte, addrLen)
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
return nil, err
}
d.FQDN = string(fqdn)
default:
return nil, unrecognizedAddrType
}
// Read the port
port := []byte{0, 0}
if _, err := io.ReadAtLeast(r, port, 2); err != nil {
return nil, err
}
d.Port = (int(port[0]) << 8) | int(port[1])
return d, nil
}
// sendReply is used to send a reply message
func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
// Format the address
var addrType uint8
var addrBody []byte
var addrPort uint16
switch {
case addr == nil:
addrType = ipv4Address
addrBody = []byte{0, 0, 0, 0}
addrPort = 0
case addr.FQDN != "":
addrType = fqdnAddress
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
addrPort = uint16(addr.Port)
case addr.IP.To4() != nil:
addrType = ipv4Address
addrBody = []byte(addr.IP.To4())
addrPort = uint16(addr.Port)
case addr.IP.To16() != nil:
addrType = ipv6Address
addrBody = []byte(addr.IP.To16())
addrPort = uint16(addr.Port)
default:
return fmt.Errorf("Failed to format address: %v", addr)
}
// Format the message
msg := make([]byte, 6+len(addrBody))
msg[0] = socks5Version
msg[1] = resp
msg[2] = 0 // Reserved
msg[3] = addrType
copy(msg[4:], addrBody)
msg[4+len(addrBody)] = byte(addrPort >> 8)
msg[4+len(addrBody)+1] = byte(addrPort & 0xff)
// Send the message
_, err := w.Write(msg)
return err
}
type closeWriter interface {
CloseWrite() error
}
// proxy is used to suffle data from src to destination, and sends errors
// down a dedicated channel
func proxy_conn(dst io.Writer, src io.Reader, errCh chan error) {
_, err := io.Copy(dst, src)
if tcpConn, ok := dst.(closeWriter); ok {
tcpConn.CloseWrite()
}
errCh <- err
}

169
go-socks-lb/request_test.go Normal file
View File

@ -0,0 +1,169 @@
package main
import (
"bytes"
"encoding/binary"
"io"
"log"
"net"
"os"
"strings"
"testing"
)
type MockConn struct {
buf bytes.Buffer
}
func (m *MockConn) Write(b []byte) (int, error) {
return m.buf.Write(b)
}
func (m *MockConn) RemoteAddr() net.Addr {
return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432}
}
func TestRequest_Connect(t *testing.T) {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
buf := make([]byte, 4)
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
// Make server
s := &Server{config: &Config{
Rules: PermitAll(),
Resolver: DNSResolver{},
Logger: log.New(os.Stdout, "", log.LstdFlags),
}}
// Create the connect request
buf := bytes.NewBuffer(nil)
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
buf.Write(port)
// Send a ping
buf.Write([]byte("ping"))
// Handle the request
resp := &MockConn{}
req, err := NewRequest(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.handleRequest(req, resp); err != nil {
t.Fatalf("err: %v", err)
}
// Verify response
out := resp.buf.Bytes()
expected := []byte{
5,
0,
0,
1,
127, 0, 0, 1,
0, 0,
'p', 'o', 'n', 'g',
}
// Ignore the port for both
out[8] = 0
out[9] = 0
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v %v", out, expected)
}
}
func TestRequest_Connect_RuleFail(t *testing.T) {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
buf := make([]byte, 4)
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
// Make server
s := &Server{config: &Config{
Rules: PermitNone(),
Resolver: DNSResolver{},
Logger: log.New(os.Stdout, "", log.LstdFlags),
}}
// Create the connect request
buf := bytes.NewBuffer(nil)
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
buf.Write(port)
// Send a ping
buf.Write([]byte("ping"))
// Handle the request
resp := &MockConn{}
req, err := NewRequest(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") {
t.Fatalf("err: %v", err)
}
// Verify response
out := resp.buf.Bytes()
expected := []byte{
5,
2,
0,
1,
0, 0, 0, 0,
0, 0,
}
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v %v", out, expected)
}
}

23
go-socks-lb/resolver.go Normal file
View File

@ -0,0 +1,23 @@
package main
import (
"net"
"golang.org/x/net/context"
)
// NameResolver is used to implement custom name resolution
type NameResolver interface {
Resolve(ctx context.Context, name string) (context.Context, net.IP, error)
}
// DNSResolver uses the system DNS to resolve host names
type DNSResolver struct{}
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
addr, err := net.ResolveIPAddr("ip", name)
if err != nil {
return ctx, nil, err
}
return ctx, addr.IP, err
}

View File

@ -0,0 +1,21 @@
package main
import (
"testing"
"golang.org/x/net/context"
)
func TestDNSResolver(t *testing.T) {
d := DNSResolver{}
ctx := context.Background()
_, addr, err := d.Resolve(ctx, "localhost")
if err != nil {
t.Fatalf("err: %v", err)
}
if !addr.IsLoopback() {
t.Fatalf("expected loopback")
}
}

41
go-socks-lb/ruleset.go Normal file
View File

@ -0,0 +1,41 @@
package main
import (
"golang.org/x/net/context"
)
// RuleSet is used to provide custom rules to allow or prohibit actions
type RuleSet interface {
Allow(ctx context.Context, req *Request) (context.Context, bool)
}
// PermitAll returns a RuleSet which allows all types of connections
func PermitAll() RuleSet {
return &PermitCommand{true, true, true}
}
// PermitNone returns a RuleSet which disallows all types of connections
func PermitNone() RuleSet {
return &PermitCommand{false, false, false}
}
// PermitCommand is an implementation of the RuleSet which
// enables filtering supported commands
type PermitCommand struct {
EnableConnect bool
EnableBind bool
EnableAssociate bool
}
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
switch req.Command {
case ConnectCommand:
return ctx, p.EnableConnect
case BindCommand:
return ctx, p.EnableBind
case AssociateCommand:
return ctx, p.EnableAssociate
}
return ctx, false
}

View File

@ -0,0 +1,24 @@
package main
import (
"testing"
"golang.org/x/net/context"
)
func TestPermitCommand(t *testing.T) {
ctx := context.Background()
r := &PermitCommand{true, false, false}
if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok {
t.Fatalf("expect connect")
}
if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok {
t.Fatalf("do not expect bind")
}
if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok {
t.Fatalf("do not expect associate")
}
}

172
go-socks-lb/socks5.go Normal file
View File

@ -0,0 +1,172 @@
package main
import (
"bufio"
"fmt"
"log"
"net"
"os"
"golang.org/x/net/context"
)
const (
socks5Version = uint8(5)
)
// Config is used to setup and configure a Server
type Config struct {
// AuthMethods can be provided to implement custom authentication
// By default, "auth-less" mode is enabled.
// For password-based auth use UserPassAuthenticator.
AuthMethods []Authenticator
// If provided, username/password authentication is enabled,
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
// and AUthMethods is nil, then "auth-less" mode is enabled.
Credentials CredentialStore
// Resolver can be provided to do custom name resolution.
// Defaults to DNSResolver if not provided.
Resolver NameResolver
// Rules is provided to enable custom logic around permitting
// various commands. If not provided, PermitAll is used.
Rules RuleSet
// Rewriter can be used to transparently rewrite addresses.
// This is invoked before the RuleSet is invoked.
// Defaults to NoRewrite.
Rewriter AddressRewriter
// BindIP is used for bind or udp associate
BindIP net.IP
// Logger can be used to provide a custom log target.
// Defaults to stdout.
Logger *log.Logger
// Managing proxy selections
PM *ProxyManager
// Optional function for dialing out
Dial func(ctx context.Context, network, addr string) (net.Conn, error)
}
// Server is reponsible for accepting connections and handling
// the details of the SOCKS5 protocol
type Server struct {
config *Config
authMethods map[uint8]Authenticator
}
// New creates a new Server and potentially returns an error
func New(conf *Config) (*Server, error) {
// Ensure we have at least one authentication method enabled
if len(conf.AuthMethods) == 0 {
if conf.Credentials != nil {
conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}}
} else {
conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}}
}
}
// Ensure we have a DNS resolver
if conf.Resolver == nil {
conf.Resolver = DNSResolver{}
}
// Ensure we have a rule set
if conf.Rules == nil {
conf.Rules = PermitAll()
}
// Ensure we have a log target
if conf.Logger == nil {
conf.Logger = log.New(os.Stdout, "", log.LstdFlags)
}
server := &Server{
config: conf,
}
server.authMethods = make(map[uint8]Authenticator)
for _, a := range conf.AuthMethods {
server.authMethods[a.GetCode()] = a
}
return server, nil
}
// ListenAndServe is used to create a listener and serve on it
func (s *Server) ListenAndServe(network, addr string) error {
l, err := net.Listen(network, addr)
if err != nil {
return err
}
return s.Serve(l)
}
// Serve is used to serve connections from a listener
func (s *Server) Serve(l net.Listener) error {
for {
conn, err := l.Accept()
if err != nil {
return err
}
go s.ServeConn(conn)
}
return nil
}
// ServeConn is used to serve a single connection.
func (s *Server) ServeConn(conn net.Conn) error {
defer conn.Close()
bufConn := bufio.NewReader(conn)
// Read the version byte
version := []byte{0}
if _, err := bufConn.Read(version); err != nil {
s.config.Logger.Printf("[ERR] socks: Failed to get version byte: %v", err)
return err
}
// Ensure we are compatible
if version[0] != socks5Version {
err := fmt.Errorf("Unsupported SOCKS version: %v", version)
s.config.Logger.Printf("[ERR] socks: %v", err)
return err
}
// Authenticate the connection
authContext, err := s.authenticate(conn, bufConn)
if err != nil {
err = fmt.Errorf("Failed to authenticate: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err)
return err
}
request, err := NewRequest(bufConn)
if err != nil {
if err == unrecognizedAddrType {
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
}
return fmt.Errorf("Failed to read destination address: %v", err)
}
request.AuthContext = authContext
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port}
}
// Process the client request
if err := s.handleRequest(request, conn); err != nil {
err = fmt.Errorf("Failed to handle request: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err)
return err
}
return nil
}

110
go-socks-lb/socks5_test.go Normal file
View File

@ -0,0 +1,110 @@
package main
import (
"bytes"
"encoding/binary"
"io"
"log"
"net"
"os"
"testing"
"time"
)
func TestSOCKS5_Connect(t *testing.T) {
// Create a local listener
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}
go func() {
conn, err := l.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()
buf := make([]byte, 4)
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
t.Fatalf("err: %v", err)
}
if !bytes.Equal(buf, []byte("ping")) {
t.Fatalf("bad: %v", buf)
}
conn.Write([]byte("pong"))
}()
lAddr := l.Addr().(*net.TCPAddr)
// Create a socks server
creds := StaticCredentials{
"foo": "bar",
}
cator := UserPassAuthenticator{Credentials: creds}
conf := &Config{
AuthMethods: []Authenticator{cator},
Logger: log.New(os.Stdout, "", log.LstdFlags),
}
serv, err := New(conf)
if err != nil {
t.Fatalf("err: %v", err)
}
// Start listening
go func() {
if err := serv.ListenAndServe("tcp", "127.0.0.1:12365"); err != nil {
t.Fatalf("err: %v", err)
}
}()
time.Sleep(10 * time.Millisecond)
// Get a local conn
conn, err := net.Dial("tcp", "127.0.0.1:12365")
if err != nil {
t.Fatalf("err: %v", err)
}
// Connect, auth and connec to local
req := bytes.NewBuffer(nil)
req.Write([]byte{5})
req.Write([]byte{2, NoAuth, UserPassAuth})
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
port := []byte{0, 0}
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
req.Write(port)
// Send a ping
req.Write([]byte("ping"))
// Send all the bytes
conn.Write(req.Bytes())
// Verify response
expected := []byte{
socks5Version, UserPassAuth,
1, authSuccess,
5,
0,
0,
1,
127, 0, 0, 1,
0, 0,
'p', 'o', 'n', 'g',
}
out := make([]byte, len(expected))
conn.SetDeadline(time.Now().Add(time.Second))
if _, err := io.ReadAtLeast(conn, out, len(out)); err != nil {
t.Fatalf("err: %v", err)
}
// Ignore the port
out[12] = 0
out[13] = 0
if !bytes.Equal(out, expected) {
t.Fatalf("bad: %v", out)
}
}