including go load balancer
This commit is contained in:
parent
ba96edbd7b
commit
99b8f3dfba
22
go-socks-lb/.gitignore
vendored
Normal file
22
go-socks-lb/.gitignore
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
20
go-socks-lb/LICENSE
Normal file
20
go-socks-lb/LICENSE
Normal file
@ -0,0 +1,20 @@
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014 Armon Dadgar
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
the Software without restriction, including without limitation the rights to
|
||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||
subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
45
go-socks-lb/README.md
Normal file
45
go-socks-lb/README.md
Normal file
@ -0,0 +1,45 @@
|
||||
go-socks5 [](https://travis-ci.org/armon/go-socks5)
|
||||
=========
|
||||
|
||||
Provides the `socks5` package that implements a [SOCKS5 server](http://en.wikipedia.org/wiki/SOCKS).
|
||||
SOCKS (Secure Sockets) is used to route traffic between a client and server through
|
||||
an intermediate proxy layer. This can be used to bypass firewalls or NATs.
|
||||
|
||||
Feature
|
||||
=======
|
||||
|
||||
The package has the following features:
|
||||
* "No Auth" mode
|
||||
* User/Password authentication
|
||||
* Support for the CONNECT command
|
||||
* Rules to do granular filtering of commands
|
||||
* Custom DNS resolution
|
||||
* Unit tests
|
||||
|
||||
TODO
|
||||
====
|
||||
|
||||
The package still needs the following:
|
||||
* Support for the BIND command
|
||||
* Support for the ASSOCIATE command
|
||||
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
Below is a simple example of usage
|
||||
|
||||
```go
|
||||
// Create a SOCKS5 server
|
||||
conf := &socks5.Config{}
|
||||
server, err := socks5.New(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create SOCKS5 proxy on localhost port 8000
|
||||
if err := server.ListenAndServe("tcp", "127.0.0.1:8000"); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
```
|
||||
|
||||
151
go-socks-lb/auth.go
Normal file
151
go-socks-lb/auth.go
Normal file
@ -0,0 +1,151 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
NoAuth = uint8(0)
|
||||
noAcceptable = uint8(255)
|
||||
UserPassAuth = uint8(2)
|
||||
userAuthVersion = uint8(1)
|
||||
authSuccess = uint8(0)
|
||||
authFailure = uint8(1)
|
||||
)
|
||||
|
||||
var (
|
||||
UserAuthFailed = fmt.Errorf("User authentication failed")
|
||||
NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
|
||||
)
|
||||
|
||||
// A Request encapsulates authentication state provided
|
||||
// during negotiation
|
||||
type AuthContext struct {
|
||||
// Provided auth method
|
||||
Method uint8
|
||||
// Payload provided during negotiation.
|
||||
// Keys depend on the used auth method.
|
||||
// For UserPassauth contains Username
|
||||
Payload map[string]string
|
||||
}
|
||||
|
||||
type Authenticator interface {
|
||||
Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
|
||||
GetCode() uint8
|
||||
}
|
||||
|
||||
// NoAuthAuthenticator is used to handle the "No Authentication" mode
|
||||
type NoAuthAuthenticator struct{}
|
||||
|
||||
func (a NoAuthAuthenticator) GetCode() uint8 {
|
||||
return NoAuth
|
||||
}
|
||||
|
||||
func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||
_, err := writer.Write([]byte{socks5Version, NoAuth})
|
||||
return &AuthContext{NoAuth, nil}, err
|
||||
}
|
||||
|
||||
// UserPassAuthenticator is used to handle username/password based
|
||||
// authentication
|
||||
type UserPassAuthenticator struct {
|
||||
Credentials CredentialStore
|
||||
}
|
||||
|
||||
func (a UserPassAuthenticator) GetCode() uint8 {
|
||||
return UserPassAuth
|
||||
}
|
||||
|
||||
func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||
// Tell the client to use user/pass auth
|
||||
if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the version and username length
|
||||
header := []byte{0, 0}
|
||||
if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure we are compatible
|
||||
if header[0] != userAuthVersion {
|
||||
return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
|
||||
}
|
||||
|
||||
// Get the user name
|
||||
userLen := int(header[1])
|
||||
user := make([]byte, userLen)
|
||||
if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the password length
|
||||
if _, err := reader.Read(header[:1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the password
|
||||
passLen := int(header[0])
|
||||
pass := make([]byte, passLen)
|
||||
if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Verify the password
|
||||
if a.Credentials.Valid(string(user), string(pass)) {
|
||||
if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, UserAuthFailed
|
||||
}
|
||||
|
||||
// Done
|
||||
return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil
|
||||
}
|
||||
|
||||
// authenticate is used to handle connection authentication
|
||||
func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
|
||||
// Get the methods
|
||||
methods, err := readMethods(bufConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to get auth methods: %v", err)
|
||||
}
|
||||
|
||||
// Select a usable method
|
||||
for _, method := range methods {
|
||||
cator, found := s.authMethods[method]
|
||||
if found {
|
||||
return cator.Authenticate(bufConn, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// No usable method found
|
||||
return nil, noAcceptableAuth(conn)
|
||||
}
|
||||
|
||||
// noAcceptableAuth is used to handle when we have no eligible
|
||||
// authentication mechanism
|
||||
func noAcceptableAuth(conn io.Writer) error {
|
||||
conn.Write([]byte{socks5Version, noAcceptable})
|
||||
return NoSupportedAuth
|
||||
}
|
||||
|
||||
// readMethods is used to read the number of methods
|
||||
// and proceeding auth methods
|
||||
func readMethods(r io.Reader) ([]byte, error) {
|
||||
header := []byte{0}
|
||||
if _, err := r.Read(header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
numMethods := int(header[0])
|
||||
methods := make([]byte, numMethods)
|
||||
_, err := io.ReadAtLeast(r, methods, numMethods)
|
||||
return methods, err
|
||||
}
|
||||
119
go-socks-lb/auth_test.go
Normal file
119
go-socks-lb/auth_test.go
Normal file
@ -0,0 +1,119 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNoAuth(t *testing.T) {
|
||||
req := bytes.NewBuffer(nil)
|
||||
req.Write([]byte{1, NoAuth})
|
||||
var resp bytes.Buffer
|
||||
|
||||
s, _ := New(&Config{})
|
||||
ctx, err := s.authenticate(&resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if ctx.Method != NoAuth {
|
||||
t.Fatal("Invalid Context Method")
|
||||
}
|
||||
|
||||
out := resp.Bytes()
|
||||
if !bytes.Equal(out, []byte{socks5Version, NoAuth}) {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordAuth_Valid(t *testing.T) {
|
||||
req := bytes.NewBuffer(nil)
|
||||
req.Write([]byte{2, NoAuth, UserPassAuth})
|
||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
||||
var resp bytes.Buffer
|
||||
|
||||
cred := StaticCredentials{
|
||||
"foo": "bar",
|
||||
}
|
||||
|
||||
cator := UserPassAuthenticator{Credentials: cred}
|
||||
|
||||
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||
|
||||
ctx, err := s.authenticate(&resp, req)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if ctx.Method != UserPassAuth {
|
||||
t.Fatal("Invalid Context Method")
|
||||
}
|
||||
|
||||
val, ok := ctx.Payload["Username"]
|
||||
if !ok {
|
||||
t.Fatal("Missing key Username in auth context's payload")
|
||||
}
|
||||
|
||||
if val != "foo" {
|
||||
t.Fatal("Invalid Username in auth context's payload")
|
||||
}
|
||||
|
||||
out := resp.Bytes()
|
||||
if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authSuccess}) {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordAuth_Invalid(t *testing.T) {
|
||||
req := bytes.NewBuffer(nil)
|
||||
req.Write([]byte{2, NoAuth, UserPassAuth})
|
||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'z'})
|
||||
var resp bytes.Buffer
|
||||
|
||||
cred := StaticCredentials{
|
||||
"foo": "bar",
|
||||
}
|
||||
cator := UserPassAuthenticator{Credentials: cred}
|
||||
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||
|
||||
ctx, err := s.authenticate(&resp, req)
|
||||
if err != UserAuthFailed {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
t.Fatal("Invalid Context Method")
|
||||
}
|
||||
|
||||
out := resp.Bytes()
|
||||
if !bytes.Equal(out, []byte{socks5Version, UserPassAuth, 1, authFailure}) {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoSupportedAuth(t *testing.T) {
|
||||
req := bytes.NewBuffer(nil)
|
||||
req.Write([]byte{1, NoAuth})
|
||||
var resp bytes.Buffer
|
||||
|
||||
cred := StaticCredentials{
|
||||
"foo": "bar",
|
||||
}
|
||||
cator := UserPassAuthenticator{Credentials: cred}
|
||||
|
||||
s, _ := New(&Config{AuthMethods: []Authenticator{cator}})
|
||||
|
||||
ctx, err := s.authenticate(&resp, req)
|
||||
if err != NoSupportedAuth {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
t.Fatal("Invalid Context Method")
|
||||
}
|
||||
|
||||
out := resp.Bytes()
|
||||
if !bytes.Equal(out, []byte{socks5Version, noAcceptable}) {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
||||
9
go-socks-lb/config.yml
Normal file
9
go-socks-lb/config.yml
Normal file
@ -0,0 +1,9 @@
|
||||
proxy:
|
||||
- url: "socks5://192.168.122.128:1080"
|
||||
weight: 5
|
||||
- url: "socks5://192.168.122.128:1081"
|
||||
weight: 5
|
||||
- url: "socks5://192.168.122.128:1082"
|
||||
weight: 5
|
||||
sticky: true
|
||||
cache-clean-interval: 10
|
||||
17
go-socks-lb/credentials.go
Normal file
17
go-socks-lb/credentials.go
Normal file
@ -0,0 +1,17 @@
|
||||
package main
|
||||
|
||||
// CredentialStore is used to support user/pass authentication
|
||||
type CredentialStore interface {
|
||||
Valid(user, password string) bool
|
||||
}
|
||||
|
||||
// StaticCredentials enables using a map directly as a credential store
|
||||
type StaticCredentials map[string]string
|
||||
|
||||
func (s StaticCredentials) Valid(user, password string) bool {
|
||||
pass, ok := s[user]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return password == pass
|
||||
}
|
||||
24
go-socks-lb/credentials_test.go
Normal file
24
go-socks-lb/credentials_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestStaticCredentials(t *testing.T) {
|
||||
creds := StaticCredentials{
|
||||
"foo": "bar",
|
||||
"baz": "",
|
||||
}
|
||||
|
||||
if !creds.Valid("foo", "bar") {
|
||||
t.Fatalf("expect valid")
|
||||
}
|
||||
|
||||
if !creds.Valid("baz", "") {
|
||||
t.Fatalf("expect valid")
|
||||
}
|
||||
|
||||
if creds.Valid("foo", "") {
|
||||
t.Fatalf("expect invalid")
|
||||
}
|
||||
}
|
||||
BIN
go-socks-lb/go-socks-lb
Executable file
BIN
go-socks-lb/go-socks-lb
Executable file
Binary file not shown.
9
go-socks-lb/go.mod
Normal file
9
go-socks-lb/go.mod
Normal file
@ -0,0 +1,9 @@
|
||||
module git.hmthsn.com/mantao/mop/go-socks-lb
|
||||
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/mroth/weightedrand v0.2.1
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e
|
||||
gopkg.in/yaml.v2 v2.2.8
|
||||
)
|
||||
11
go-socks-lb/go.sum
Normal file
11
go-socks-lb/go.sum
Normal file
@ -0,0 +1,11 @@
|
||||
github.com/mroth/weightedrand v0.2.1 h1:ivJastXlhBrj0q931DJ8IwhOLGwrYtPeENWd3WlVI0s=
|
||||
github.com/mroth/weightedrand v0.2.1/go.mod h1:3p2SIcC8al1YMzGhAIoXD+r9olo/g/cdJgAD905gyNE=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e h1:3G+cUijn7XD+S4eJFddp53Pv7+slrESplyjG25HgL+k=
|
||||
golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
207
go-socks-lb/main.go
Normal file
207
go-socks-lb/main.go
Normal file
@ -0,0 +1,207 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
wr "github.com/mroth/weightedrand"
|
||||
"golang.org/x/net/proxy"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
URL = "http://connectivitycheck.gstatic.com/generate_204"
|
||||
)
|
||||
|
||||
func proxyTestStatusCode(proxyURL string, URL string, StatusCode int) bool {
|
||||
// create a socks5 dialer
|
||||
u, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
fmt.Println("error parsing")
|
||||
return false
|
||||
}
|
||||
dialer, err := proxy.FromURL(u, proxy.Direct)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "can't connect to the proxy:", err)
|
||||
return false
|
||||
}
|
||||
// setup a http client
|
||||
httpTransport := &http.Transport{}
|
||||
httpClient := &http.Client{
|
||||
Transport: httpTransport,
|
||||
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||
fmt.Printf("redirect %v", *req)
|
||||
return http.ErrUseLastResponse
|
||||
},
|
||||
}
|
||||
// set our socks5 as the dialer
|
||||
httpTransport.Dial = dialer.Dial
|
||||
// create a request
|
||||
req, err := http.NewRequest("GET", URL, nil)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "can't create request:", err)
|
||||
return false
|
||||
}
|
||||
// use the http client to fetch the page
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, "can't GET page:", err)
|
||||
return false
|
||||
}
|
||||
// defer resp.Body.Close()
|
||||
// b, err := ioutil.ReadAll(resp.Body)
|
||||
// if err != nil {
|
||||
// fmt.Fprintln(os.Stderr, "error reading body:", err)
|
||||
// return false
|
||||
// }
|
||||
return resp.StatusCode == StatusCode
|
||||
}
|
||||
|
||||
func testTask(testfn func() bool, interval int, stop chan bool) {
|
||||
for {
|
||||
go testfn()
|
||||
select {
|
||||
case <-stop:
|
||||
fmt.Println("stopping")
|
||||
return
|
||||
default:
|
||||
}
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
type proxyInst struct {
|
||||
URL string
|
||||
Weight int
|
||||
StatusHistory []bool
|
||||
}
|
||||
type ProxyManager struct {
|
||||
Proxys []proxyInst
|
||||
Cache map[string]int
|
||||
Sticky bool
|
||||
CacheCleanInterval int
|
||||
Chooser wr.Chooser
|
||||
mux sync.Mutex
|
||||
}
|
||||
type proxyConf struct {
|
||||
Proxys []struct {
|
||||
URL string `yaml:"url"`
|
||||
Weight int `yaml:"weight"`
|
||||
} `yaml:"proxy"`
|
||||
Sticky bool `yaml:"sticky"`
|
||||
CacheCleanInterval int `yaml:"cache-clean-interval"`
|
||||
}
|
||||
|
||||
func NewPM(confFp string) (*ProxyManager, error) {
|
||||
pm := ProxyManager{}
|
||||
f, err := os.Open(confFp)
|
||||
if err != nil {
|
||||
return &ProxyManager{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var cfg proxyConf
|
||||
decoder := yaml.NewDecoder(f)
|
||||
err = decoder.Decode(&cfg)
|
||||
if err != nil {
|
||||
return &ProxyManager{}, err
|
||||
}
|
||||
var chooseArr []wr.Choice
|
||||
for idx, pc := range cfg.Proxys {
|
||||
var pi proxyInst
|
||||
pi.URL = pc.URL
|
||||
pi.Weight = pc.Weight
|
||||
pm.Proxys = append(pm.Proxys, pi)
|
||||
chooseArr = append(chooseArr, wr.Choice{Item: idx, Weight: uint(pi.Weight)})
|
||||
}
|
||||
pm.Sticky = cfg.Sticky
|
||||
pm.CacheCleanInterval = cfg.CacheCleanInterval
|
||||
pm.Cache = make(map[string]int)
|
||||
pm.Chooser = wr.NewChooser(chooseArr...)
|
||||
go pm.keepClearingCache()
|
||||
return &pm, nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) Get(addr string) string {
|
||||
pm.mux.Lock()
|
||||
defer pm.mux.Unlock()
|
||||
// fmt.Println(pm.Cache)
|
||||
if pm.Sticky {
|
||||
idx, ok := pm.Cache[addr]
|
||||
if ok {
|
||||
// fmt.Println("match addr", addr, "using:", idx)
|
||||
return pm.Proxys[idx].URL
|
||||
}
|
||||
}
|
||||
idx := pm.Chooser.Pick().(int)
|
||||
pm.Cache[addr] = idx
|
||||
// fmt.Println("addr", addr, "using:", idx)
|
||||
return pm.Proxys[idx].URL
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) ClearCache() {
|
||||
// fmt.Println("clearing cache")
|
||||
pm.mux.Lock()
|
||||
pm.Cache = make(map[string]int)
|
||||
pm.mux.Unlock()
|
||||
// fmt.Println("after:", pm.Cache)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) keepClearingCache() {
|
||||
interval := pm.CacheCleanInterval
|
||||
if interval <= 0 {
|
||||
return
|
||||
}
|
||||
for {
|
||||
time.Sleep(time.Duration(interval) * time.Second)
|
||||
pm.ClearCache()
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
// testFn := func() bool {
|
||||
// status := proxyTestStatusCode("socks5://127.0.0.1:1080", URL, 204)
|
||||
// fmt.Println(status)
|
||||
// return status
|
||||
// }
|
||||
// stop := make(chan bool)
|
||||
// go testTask(testFn, 5, stop)
|
||||
// time.Sleep(10 * time.Second)
|
||||
// stop <- true
|
||||
// time.Sleep(10 * time.Second)
|
||||
|
||||
configPath := flag.String("config", "", "Config file.")
|
||||
bindAddr := flag.String("bind", "127.0.0.1:7000", "Bind address and port")
|
||||
flag.Parse()
|
||||
if *configPath == "" {
|
||||
flag.PrintDefaults()
|
||||
os.Exit(1)
|
||||
}
|
||||
rand.Seed(time.Now().UTC().UnixNano()) // always seed random!
|
||||
|
||||
var pm *ProxyManager
|
||||
pm, err := NewPM(*configPath)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if pm.Sticky {
|
||||
fmt.Printf("Sticky with interval %ds.\n", pm.CacheCleanInterval)
|
||||
} else {
|
||||
fmt.Println("Randomize every connection.")
|
||||
}
|
||||
conf := &Config{PM: pm}
|
||||
server, err := New(conf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Create SOCKS5 proxy on localhost port 8000
|
||||
if err := server.ListenAndServe("tcp", *bindAddr); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
367
go-socks-lb/request.go
Normal file
367
go-socks-lb/request.go
Normal file
@ -0,0 +1,367 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"golang.org/x/net/proxy"
|
||||
)
|
||||
|
||||
const (
|
||||
ConnectCommand = uint8(1)
|
||||
BindCommand = uint8(2)
|
||||
AssociateCommand = uint8(3)
|
||||
ipv4Address = uint8(1)
|
||||
fqdnAddress = uint8(3)
|
||||
ipv6Address = uint8(4)
|
||||
)
|
||||
|
||||
const (
|
||||
successReply uint8 = iota
|
||||
serverFailure
|
||||
ruleFailure
|
||||
networkUnreachable
|
||||
hostUnreachable
|
||||
connectionRefused
|
||||
ttlExpired
|
||||
commandNotSupported
|
||||
addrTypeNotSupported
|
||||
)
|
||||
|
||||
var (
|
||||
unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
|
||||
)
|
||||
|
||||
// AddressRewriter is used to rewrite a destination transparently
|
||||
type AddressRewriter interface {
|
||||
Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
|
||||
}
|
||||
|
||||
// AddrSpec is used to return the target AddrSpec
|
||||
// which may be specified as IPv4, IPv6, or a FQDN
|
||||
type AddrSpec struct {
|
||||
FQDN string
|
||||
IP net.IP
|
||||
Port int
|
||||
}
|
||||
|
||||
func (a *AddrSpec) String() string {
|
||||
if a.FQDN != "" {
|
||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
||||
}
|
||||
|
||||
// Address returns a string suitable to dial; prefer returning IP-based
|
||||
// address, fallback to FQDN
|
||||
func (a AddrSpec) Address() string {
|
||||
if 0 != len(a.IP) {
|
||||
return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
|
||||
}
|
||||
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
|
||||
}
|
||||
|
||||
// A Request represents request received by a server
|
||||
type Request struct {
|
||||
// Protocol version
|
||||
Version uint8
|
||||
// Requested command
|
||||
Command uint8
|
||||
// AuthContext provided during negotiation
|
||||
AuthContext *AuthContext
|
||||
// AddrSpec of the the network that sent the request
|
||||
RemoteAddr *AddrSpec
|
||||
// AddrSpec of the desired destination
|
||||
DestAddr *AddrSpec
|
||||
// AddrSpec of the actual destination (might be affected by rewrite)
|
||||
realDestAddr *AddrSpec
|
||||
bufConn io.Reader
|
||||
}
|
||||
|
||||
type conn interface {
|
||||
Write([]byte) (int, error)
|
||||
RemoteAddr() net.Addr
|
||||
}
|
||||
|
||||
// NewRequest creates a new Request from the tcp connection
|
||||
func NewRequest(bufConn io.Reader) (*Request, error) {
|
||||
// Read the version byte
|
||||
header := []byte{0, 0, 0}
|
||||
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
|
||||
return nil, fmt.Errorf("Failed to get command version: %v", err)
|
||||
}
|
||||
|
||||
// Ensure we are compatible
|
||||
if header[0] != socks5Version {
|
||||
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
|
||||
}
|
||||
|
||||
// Read in the destination address
|
||||
dest, err := readAddrSpec(bufConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request := &Request{
|
||||
Version: socks5Version,
|
||||
Command: header[1],
|
||||
DestAddr: dest,
|
||||
bufConn: bufConn,
|
||||
}
|
||||
|
||||
return request, nil
|
||||
}
|
||||
|
||||
// handleRequest is used for request processing after authentication
|
||||
func (s *Server) handleRequest(req *Request, conn conn) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// Resolve the address if we have a FQDN
|
||||
// dest := req.DestAddr
|
||||
// if dest.FQDN != "" {
|
||||
// ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
|
||||
// if err != nil {
|
||||
// if err := sendReply(conn, hostUnreachable, nil); err != nil {
|
||||
// return fmt.Errorf("Failed to send reply: %v", err)
|
||||
// }
|
||||
// return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
|
||||
// }
|
||||
// ctx = ctx_
|
||||
// dest.IP = addr
|
||||
// }
|
||||
|
||||
// Apply any address rewrites
|
||||
req.realDestAddr = req.DestAddr
|
||||
if s.config.Rewriter != nil {
|
||||
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
|
||||
}
|
||||
|
||||
// Switch on the command
|
||||
switch req.Command {
|
||||
case ConnectCommand:
|
||||
return s.handleConnect(ctx, conn, req)
|
||||
case BindCommand:
|
||||
return s.handleBind(ctx, conn, req)
|
||||
case AssociateCommand:
|
||||
return s.handleAssociate(ctx, conn, req)
|
||||
default:
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Unsupported command: %v", req.Command)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnect is used to handle a connect command
|
||||
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
|
||||
} else {
|
||||
ctx = ctx_
|
||||
}
|
||||
|
||||
// Attempt to connect
|
||||
// dial := s.config.Dial
|
||||
dial := func(ctx context.Context, net_, addr string) (net.Conn, error) {
|
||||
var dp proxy.Dialer
|
||||
u, _ := url.Parse(s.config.PM.Get(addr))
|
||||
dp, _ = proxy.FromURL(u, proxy.Direct)
|
||||
return dp.Dial(net_, addr)
|
||||
}
|
||||
target, err := dial(ctx, "tcp", req.realDestAddr.Address())
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
resp := hostUnreachable
|
||||
if strings.Contains(msg, "refused") {
|
||||
resp = connectionRefused
|
||||
} else if strings.Contains(msg, "network is unreachable") {
|
||||
resp = networkUnreachable
|
||||
}
|
||||
if err := sendReply(conn, resp, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
|
||||
}
|
||||
defer target.Close()
|
||||
|
||||
// Send success
|
||||
local := target.LocalAddr().(*net.TCPAddr)
|
||||
bind := AddrSpec{IP: local.IP, Port: local.Port}
|
||||
if err := sendReply(conn, successReply, &bind); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
|
||||
// Start proxying
|
||||
errCh := make(chan error, 2)
|
||||
go proxy_conn(target, req.bufConn, errCh)
|
||||
go proxy_conn(conn, target, errCh)
|
||||
|
||||
// Wait
|
||||
for i := 0; i < 2; i++ {
|
||||
e := <-errCh
|
||||
if e != nil {
|
||||
// return from this function closes target (and conn).
|
||||
return e
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleBind is used to handle a connect command
|
||||
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
|
||||
} else {
|
||||
ctx = ctx_
|
||||
}
|
||||
|
||||
// TODO: Support bind
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleAssociate is used to handle a connect command
|
||||
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
|
||||
} else {
|
||||
ctx = ctx_
|
||||
}
|
||||
|
||||
// TODO: Support associate
|
||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// readAddrSpec is used to read AddrSpec.
|
||||
// Expects an address type byte, follwed by the address and port
|
||||
func readAddrSpec(r io.Reader) (*AddrSpec, error) {
|
||||
d := &AddrSpec{}
|
||||
|
||||
// Get the address type
|
||||
addrType := []byte{0}
|
||||
if _, err := r.Read(addrType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle on a per type basis
|
||||
switch addrType[0] {
|
||||
case ipv4Address:
|
||||
addr := make([]byte, 4)
|
||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.IP = net.IP(addr)
|
||||
|
||||
case ipv6Address:
|
||||
addr := make([]byte, 16)
|
||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.IP = net.IP(addr)
|
||||
|
||||
case fqdnAddress:
|
||||
if _, err := r.Read(addrType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addrLen := int(addrType[0])
|
||||
fqdn := make([]byte, addrLen)
|
||||
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.FQDN = string(fqdn)
|
||||
|
||||
default:
|
||||
return nil, unrecognizedAddrType
|
||||
}
|
||||
|
||||
// Read the port
|
||||
port := []byte{0, 0}
|
||||
if _, err := io.ReadAtLeast(r, port, 2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.Port = (int(port[0]) << 8) | int(port[1])
|
||||
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// sendReply is used to send a reply message
|
||||
func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
|
||||
// Format the address
|
||||
var addrType uint8
|
||||
var addrBody []byte
|
||||
var addrPort uint16
|
||||
switch {
|
||||
case addr == nil:
|
||||
addrType = ipv4Address
|
||||
addrBody = []byte{0, 0, 0, 0}
|
||||
addrPort = 0
|
||||
|
||||
case addr.FQDN != "":
|
||||
addrType = fqdnAddress
|
||||
addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
|
||||
addrPort = uint16(addr.Port)
|
||||
|
||||
case addr.IP.To4() != nil:
|
||||
addrType = ipv4Address
|
||||
addrBody = []byte(addr.IP.To4())
|
||||
addrPort = uint16(addr.Port)
|
||||
|
||||
case addr.IP.To16() != nil:
|
||||
addrType = ipv6Address
|
||||
addrBody = []byte(addr.IP.To16())
|
||||
addrPort = uint16(addr.Port)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("Failed to format address: %v", addr)
|
||||
}
|
||||
|
||||
// Format the message
|
||||
msg := make([]byte, 6+len(addrBody))
|
||||
msg[0] = socks5Version
|
||||
msg[1] = resp
|
||||
msg[2] = 0 // Reserved
|
||||
msg[3] = addrType
|
||||
copy(msg[4:], addrBody)
|
||||
msg[4+len(addrBody)] = byte(addrPort >> 8)
|
||||
msg[4+len(addrBody)+1] = byte(addrPort & 0xff)
|
||||
|
||||
// Send the message
|
||||
_, err := w.Write(msg)
|
||||
return err
|
||||
}
|
||||
|
||||
type closeWriter interface {
|
||||
CloseWrite() error
|
||||
}
|
||||
|
||||
// proxy is used to suffle data from src to destination, and sends errors
|
||||
// down a dedicated channel
|
||||
func proxy_conn(dst io.Writer, src io.Reader, errCh chan error) {
|
||||
_, err := io.Copy(dst, src)
|
||||
if tcpConn, ok := dst.(closeWriter); ok {
|
||||
tcpConn.CloseWrite()
|
||||
}
|
||||
errCh <- err
|
||||
}
|
||||
169
go-socks-lb/request_test.go
Normal file
169
go-socks-lb/request_test.go
Normal file
@ -0,0 +1,169 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type MockConn struct {
|
||||
buf bytes.Buffer
|
||||
}
|
||||
|
||||
func (m *MockConn) Write(b []byte) (int, error) {
|
||||
return m.buf.Write(b)
|
||||
}
|
||||
|
||||
func (m *MockConn) RemoteAddr() net.Addr {
|
||||
return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432}
|
||||
}
|
||||
|
||||
func TestRequest_Connect(t *testing.T) {
|
||||
// Create a local listener
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(buf, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", buf)
|
||||
}
|
||||
conn.Write([]byte("pong"))
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
// Make server
|
||||
s := &Server{config: &Config{
|
||||
Rules: PermitAll(),
|
||||
Resolver: DNSResolver{},
|
||||
Logger: log.New(os.Stdout, "", log.LstdFlags),
|
||||
}}
|
||||
|
||||
// Create the connect request
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
|
||||
port := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||
buf.Write(port)
|
||||
|
||||
// Send a ping
|
||||
buf.Write([]byte("ping"))
|
||||
|
||||
// Handle the request
|
||||
resp := &MockConn{}
|
||||
req, err := NewRequest(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.handleRequest(req, resp); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
out := resp.buf.Bytes()
|
||||
expected := []byte{
|
||||
5,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
127, 0, 0, 1,
|
||||
0, 0,
|
||||
'p', 'o', 'n', 'g',
|
||||
}
|
||||
|
||||
// Ignore the port for both
|
||||
out[8] = 0
|
||||
out[9] = 0
|
||||
|
||||
if !bytes.Equal(out, expected) {
|
||||
t.Fatalf("bad: %v %v", out, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequest_Connect_RuleFail(t *testing.T) {
|
||||
// Create a local listener
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(buf, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", buf)
|
||||
}
|
||||
conn.Write([]byte("pong"))
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
// Make server
|
||||
s := &Server{config: &Config{
|
||||
Rules: PermitNone(),
|
||||
Resolver: DNSResolver{},
|
||||
Logger: log.New(os.Stdout, "", log.LstdFlags),
|
||||
}}
|
||||
|
||||
// Create the connect request
|
||||
buf := bytes.NewBuffer(nil)
|
||||
buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
|
||||
port := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||
buf.Write(port)
|
||||
|
||||
// Send a ping
|
||||
buf.Write([]byte("ping"))
|
||||
|
||||
// Handle the request
|
||||
resp := &MockConn{}
|
||||
req, err := NewRequest(buf)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Verify response
|
||||
out := resp.buf.Bytes()
|
||||
expected := []byte{
|
||||
5,
|
||||
2,
|
||||
0,
|
||||
1,
|
||||
0, 0, 0, 0,
|
||||
0, 0,
|
||||
}
|
||||
|
||||
if !bytes.Equal(out, expected) {
|
||||
t.Fatalf("bad: %v %v", out, expected)
|
||||
}
|
||||
}
|
||||
23
go-socks-lb/resolver.go
Normal file
23
go-socks-lb/resolver.go
Normal file
@ -0,0 +1,23 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// NameResolver is used to implement custom name resolution
|
||||
type NameResolver interface {
|
||||
Resolve(ctx context.Context, name string) (context.Context, net.IP, error)
|
||||
}
|
||||
|
||||
// DNSResolver uses the system DNS to resolve host names
|
||||
type DNSResolver struct{}
|
||||
|
||||
func (d DNSResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
|
||||
addr, err := net.ResolveIPAddr("ip", name)
|
||||
if err != nil {
|
||||
return ctx, nil, err
|
||||
}
|
||||
return ctx, addr.IP, err
|
||||
}
|
||||
21
go-socks-lb/resolver_test.go
Normal file
21
go-socks-lb/resolver_test.go
Normal file
@ -0,0 +1,21 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestDNSResolver(t *testing.T) {
|
||||
d := DNSResolver{}
|
||||
ctx := context.Background()
|
||||
|
||||
_, addr, err := d.Resolve(ctx, "localhost")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if !addr.IsLoopback() {
|
||||
t.Fatalf("expected loopback")
|
||||
}
|
||||
}
|
||||
41
go-socks-lb/ruleset.go
Normal file
41
go-socks-lb/ruleset.go
Normal file
@ -0,0 +1,41 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// RuleSet is used to provide custom rules to allow or prohibit actions
|
||||
type RuleSet interface {
|
||||
Allow(ctx context.Context, req *Request) (context.Context, bool)
|
||||
}
|
||||
|
||||
// PermitAll returns a RuleSet which allows all types of connections
|
||||
func PermitAll() RuleSet {
|
||||
return &PermitCommand{true, true, true}
|
||||
}
|
||||
|
||||
// PermitNone returns a RuleSet which disallows all types of connections
|
||||
func PermitNone() RuleSet {
|
||||
return &PermitCommand{false, false, false}
|
||||
}
|
||||
|
||||
// PermitCommand is an implementation of the RuleSet which
|
||||
// enables filtering supported commands
|
||||
type PermitCommand struct {
|
||||
EnableConnect bool
|
||||
EnableBind bool
|
||||
EnableAssociate bool
|
||||
}
|
||||
|
||||
func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) {
|
||||
switch req.Command {
|
||||
case ConnectCommand:
|
||||
return ctx, p.EnableConnect
|
||||
case BindCommand:
|
||||
return ctx, p.EnableBind
|
||||
case AssociateCommand:
|
||||
return ctx, p.EnableAssociate
|
||||
}
|
||||
|
||||
return ctx, false
|
||||
}
|
||||
24
go-socks-lb/ruleset_test.go
Normal file
24
go-socks-lb/ruleset_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestPermitCommand(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
r := &PermitCommand{true, false, false}
|
||||
|
||||
if _, ok := r.Allow(ctx, &Request{Command: ConnectCommand}); !ok {
|
||||
t.Fatalf("expect connect")
|
||||
}
|
||||
|
||||
if _, ok := r.Allow(ctx, &Request{Command: BindCommand}); ok {
|
||||
t.Fatalf("do not expect bind")
|
||||
}
|
||||
|
||||
if _, ok := r.Allow(ctx, &Request{Command: AssociateCommand}); ok {
|
||||
t.Fatalf("do not expect associate")
|
||||
}
|
||||
}
|
||||
172
go-socks-lb/socks5.go
Normal file
172
go-socks-lb/socks5.go
Normal file
@ -0,0 +1,172 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
const (
|
||||
socks5Version = uint8(5)
|
||||
)
|
||||
|
||||
// Config is used to setup and configure a Server
|
||||
type Config struct {
|
||||
// AuthMethods can be provided to implement custom authentication
|
||||
// By default, "auth-less" mode is enabled.
|
||||
// For password-based auth use UserPassAuthenticator.
|
||||
AuthMethods []Authenticator
|
||||
|
||||
// If provided, username/password authentication is enabled,
|
||||
// by appending a UserPassAuthenticator to AuthMethods. If not provided,
|
||||
// and AUthMethods is nil, then "auth-less" mode is enabled.
|
||||
Credentials CredentialStore
|
||||
|
||||
// Resolver can be provided to do custom name resolution.
|
||||
// Defaults to DNSResolver if not provided.
|
||||
Resolver NameResolver
|
||||
|
||||
// Rules is provided to enable custom logic around permitting
|
||||
// various commands. If not provided, PermitAll is used.
|
||||
Rules RuleSet
|
||||
|
||||
// Rewriter can be used to transparently rewrite addresses.
|
||||
// This is invoked before the RuleSet is invoked.
|
||||
// Defaults to NoRewrite.
|
||||
Rewriter AddressRewriter
|
||||
|
||||
// BindIP is used for bind or udp associate
|
||||
BindIP net.IP
|
||||
|
||||
// Logger can be used to provide a custom log target.
|
||||
// Defaults to stdout.
|
||||
Logger *log.Logger
|
||||
|
||||
// Managing proxy selections
|
||||
PM *ProxyManager
|
||||
|
||||
// Optional function for dialing out
|
||||
Dial func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// Server is reponsible for accepting connections and handling
|
||||
// the details of the SOCKS5 protocol
|
||||
type Server struct {
|
||||
config *Config
|
||||
authMethods map[uint8]Authenticator
|
||||
}
|
||||
|
||||
// New creates a new Server and potentially returns an error
|
||||
func New(conf *Config) (*Server, error) {
|
||||
// Ensure we have at least one authentication method enabled
|
||||
if len(conf.AuthMethods) == 0 {
|
||||
if conf.Credentials != nil {
|
||||
conf.AuthMethods = []Authenticator{&UserPassAuthenticator{conf.Credentials}}
|
||||
} else {
|
||||
conf.AuthMethods = []Authenticator{&NoAuthAuthenticator{}}
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure we have a DNS resolver
|
||||
if conf.Resolver == nil {
|
||||
conf.Resolver = DNSResolver{}
|
||||
}
|
||||
|
||||
// Ensure we have a rule set
|
||||
if conf.Rules == nil {
|
||||
conf.Rules = PermitAll()
|
||||
}
|
||||
|
||||
// Ensure we have a log target
|
||||
if conf.Logger == nil {
|
||||
conf.Logger = log.New(os.Stdout, "", log.LstdFlags)
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
config: conf,
|
||||
}
|
||||
|
||||
server.authMethods = make(map[uint8]Authenticator)
|
||||
|
||||
for _, a := range conf.AuthMethods {
|
||||
server.authMethods[a.GetCode()] = a
|
||||
}
|
||||
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// ListenAndServe is used to create a listener and serve on it
|
||||
func (s *Server) ListenAndServe(network, addr string) error {
|
||||
l, err := net.Listen(network, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Serve(l)
|
||||
}
|
||||
|
||||
// Serve is used to serve connections from a listener
|
||||
func (s *Server) Serve(l net.Listener) error {
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.ServeConn(conn)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ServeConn is used to serve a single connection.
|
||||
func (s *Server) ServeConn(conn net.Conn) error {
|
||||
defer conn.Close()
|
||||
bufConn := bufio.NewReader(conn)
|
||||
|
||||
// Read the version byte
|
||||
version := []byte{0}
|
||||
if _, err := bufConn.Read(version); err != nil {
|
||||
s.config.Logger.Printf("[ERR] socks: Failed to get version byte: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure we are compatible
|
||||
if version[0] != socks5Version {
|
||||
err := fmt.Errorf("Unsupported SOCKS version: %v", version)
|
||||
s.config.Logger.Printf("[ERR] socks: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Authenticate the connection
|
||||
authContext, err := s.authenticate(conn, bufConn)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Failed to authenticate: %v", err)
|
||||
s.config.Logger.Printf("[ERR] socks: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
request, err := NewRequest(bufConn)
|
||||
if err != nil {
|
||||
if err == unrecognizedAddrType {
|
||||
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
|
||||
return fmt.Errorf("Failed to send reply: %v", err)
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("Failed to read destination address: %v", err)
|
||||
}
|
||||
request.AuthContext = authContext
|
||||
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
|
||||
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port}
|
||||
}
|
||||
|
||||
// Process the client request
|
||||
if err := s.handleRequest(request, conn); err != nil {
|
||||
err = fmt.Errorf("Failed to handle request: %v", err)
|
||||
s.config.Logger.Printf("[ERR] socks: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
110
go-socks-lb/socks5_test.go
Normal file
110
go-socks-lb/socks5_test.go
Normal file
@ -0,0 +1,110 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSOCKS5_Connect(t *testing.T) {
|
||||
// Create a local listener
|
||||
l, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
go func() {
|
||||
conn, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(buf, []byte("ping")) {
|
||||
t.Fatalf("bad: %v", buf)
|
||||
}
|
||||
conn.Write([]byte("pong"))
|
||||
}()
|
||||
lAddr := l.Addr().(*net.TCPAddr)
|
||||
|
||||
// Create a socks server
|
||||
creds := StaticCredentials{
|
||||
"foo": "bar",
|
||||
}
|
||||
cator := UserPassAuthenticator{Credentials: creds}
|
||||
conf := &Config{
|
||||
AuthMethods: []Authenticator{cator},
|
||||
Logger: log.New(os.Stdout, "", log.LstdFlags),
|
||||
}
|
||||
serv, err := New(conf)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Start listening
|
||||
go func() {
|
||||
if err := serv.ListenAndServe("tcp", "127.0.0.1:12365"); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Get a local conn
|
||||
conn, err := net.Dial("tcp", "127.0.0.1:12365")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Connect, auth and connec to local
|
||||
req := bytes.NewBuffer(nil)
|
||||
req.Write([]byte{5})
|
||||
req.Write([]byte{2, NoAuth, UserPassAuth})
|
||||
req.Write([]byte{1, 3, 'f', 'o', 'o', 3, 'b', 'a', 'r'})
|
||||
req.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
|
||||
|
||||
port := []byte{0, 0}
|
||||
binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
|
||||
req.Write(port)
|
||||
|
||||
// Send a ping
|
||||
req.Write([]byte("ping"))
|
||||
|
||||
// Send all the bytes
|
||||
conn.Write(req.Bytes())
|
||||
|
||||
// Verify response
|
||||
expected := []byte{
|
||||
socks5Version, UserPassAuth,
|
||||
1, authSuccess,
|
||||
5,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
127, 0, 0, 1,
|
||||
0, 0,
|
||||
'p', 'o', 'n', 'g',
|
||||
}
|
||||
out := make([]byte, len(expected))
|
||||
|
||||
conn.SetDeadline(time.Now().Add(time.Second))
|
||||
if _, err := io.ReadAtLeast(conn, out, len(out)); err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
|
||||
// Ignore the port
|
||||
out[12] = 0
|
||||
out[13] = 0
|
||||
|
||||
if !bytes.Equal(out, expected) {
|
||||
t.Fatalf("bad: %v", out)
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user