借助 io.Copy,实现 SSH 隧道的数据转发非常简单:
// Endpoint is a network address expressed as a host and a TCP port.
type Endpoint struct {
	// Host is a host name or IP literal (IPv4 or IPv6).
	Host string
	// Port is the TCP port number.
	Port int
}

// String renders the endpoint as a "host:port" address suitable for
// net.Listen and the Dial functions.
//
// IPv6 literals are wrapped in square brackets ("[::1]:22"); plain
// "%s:%d" formatting would produce an ambiguous, unparseable address for
// them. A host that the caller has already bracketed is passed through
// unchanged so it is not bracketed twice.
func (endpoint *Endpoint) String() string {
	port := fmt.Sprint(endpoint.Port)
	host := endpoint.Host
	if len(host) > 0 && host[0] == '[' {
		return host + ":" + port
	}
	return net.JoinHostPort(host, port)
}
// SSHtunnel describes a local port forward carried over an SSH connection:
// connections accepted on Local are relayed, through the SSH server at
// Server, to Remote.
type SSHtunnel struct {
// Local is the address Start listens on for incoming connections.
Local *Endpoint
// Server is the SSH server address that forward dials with ssh.Dial.
Server *Endpoint
// Remote is the destination address dialed through the SSH connection.
Remote *Endpoint
// Config is the SSH client configuration passed to ssh.Dial.
Config *ssh.ClientConfig
}
// Start opens a TCP listener on the tunnel's Local endpoint and hands every
// accepted connection to forward on its own goroutine.
//
// It blocks for as long as the listener keeps accepting and returns the
// first error from net.Listen or Accept. The listener is closed on return.
func (tunnel *SSHtunnel) Start() error {
	ln, listenErr := net.Listen("tcp", tunnel.Local.String())
	if listenErr != nil {
		return listenErr
	}
	defer ln.Close()

	for {
		client, acceptErr := ln.Accept()
		if acceptErr != nil {
			return acceptErr
		}
		// Each client gets its own goroutine so accepting is never blocked
		// by a slow SSH or remote dial.
		go tunnel.forward(client)
	}
}
// forward relays traffic between localConn and the tunnel's Remote endpoint
// over a fresh SSH connection to the tunnel's Server endpoint.
//
// It takes ownership of localConn: the connection is closed on every path,
// including dial failures. The SSH client is closed when the remote dial
// fails and, on success, once both copy directions have finished, so no SSH
// connection is leaked per forwarded client. Errors are reported on stdout.
// The method returns as soon as the copy goroutines are started.
func (tunnel *SSHtunnel) forward(localConn net.Conn) {
	serverConn, err := ssh.Dial("tcp", tunnel.Server.String(), tunnel.Config)
	if err != nil {
		fmt.Printf("Server dial error: %s\n", err)
		localConn.Close()
		return
	}

	remoteConn, err := serverConn.Dial("tcp", tunnel.Remote.String())
	if err != nil {
		fmt.Printf("Remote dial error: %s\n", err)
		serverConn.Close()
		localConn.Close()
		return
	}

	// Buffered so neither copy goroutine can block while signalling completion.
	done := make(chan struct{}, 2)

	copyConn := func(writer, reader net.Conn) {
		defer func() { done <- struct{}{} }()
		// Closing both ends when one direction finishes unblocks the
		// io.Copy running in the opposite direction.
		defer writer.Close()
		defer reader.Close()
		if _, err := io.Copy(writer, reader); err != nil {
			fmt.Printf("io.Copy error: %s\n", err)
		}
	}
	go copyConn(localConn, remoteConn)
	go copyConn(remoteConn, localConn)

	// Tear down the SSH connection only after both directions are done.
	go func() {
		<-done
		<-done
		serverConn.Close()
	}()
}
// main configures a tunnel that listens on localhost:9000 and forwards each
// connection through the SSH server to www.baidu.com:80, then runs it until
// the listener fails. The error returned by Start is reported instead of
// being silently dropped.
func main() {
	localEndpoint := &Endpoint{
		Host: "localhost",
		Port: 9000,
	}
	serverEndpoint := &Endpoint{
		Host: "some-real-ssh-listening-port",
		Port: 22,
	}
	remoteEndpoint := &Endpoint{
		Host: "www.baidu.com",
		Port: 80,
	}
	sshConfig := &ssh.ClientConfig{
		User: "root",
		Auth: []ssh.AuthMethod{
			ssh.Password("real-password"),
		},
		// SECURITY: this callback accepts any host key, which leaves the
		// connection open to man-in-the-middle attacks. Acceptable for a
		// demo only; verify the server key in real deployments.
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
			return nil
		},
	}
	tunnel := &SSHtunnel{
		Config: sshConfig,
		Local:  localEndpoint,
		Server: serverEndpoint,
		Remote: remoteEndpoint,
	}
	// Start only returns on failure (listen or accept error).
	if err := tunnel.Start(); err != nil {
		fmt.Printf("Tunnel error: %s\n", err)
	}
}