diff --git a/cmd/client.go b/cmd/client.go index 66c7a94..b231706 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -15,17 +15,23 @@ var clientCmd = &cobra.Command{ Short: "Initiates a local proxy to the remote server", Long: "Initiates a local proxy to the remote server", Run: func(cmd *cobra.Command, args []string) { - port, err := cmd.Flags().GetUint16("port") + serverAddr, err := cmd.Flags().GetString("serverAddr") if err != nil { panic(err) } - client.Start(port) + + localAddr, err := cmd.Flags().GetString("localAddr") + if err != nil { + panic(err) + } + + client.Start(serverAddr, localAddr) }, } func init() { rootCmd.AddCommand(clientCmd) - clientCmd.Flags().Uint16P("port", "p", 0, "Local port to expose") - clientCmd.MarkFlagRequired("port") + clientCmd.Flags().StringP("serverAddr", "s", "trok.tux.rs:1337", "Remote server address") + clientCmd.Flags().StringP("localAddr", "a", "0.0.0.0:80", "Local addr to expose") } diff --git a/internal/client/init.go b/internal/client/init.go index d658340..38c3d3a 100644 --- a/internal/client/init.go +++ b/internal/client/init.go @@ -11,9 +11,9 @@ import ( "github.com/rs/zerolog/log" ) -func Start(port uint16) { - var trok Trok - if err := trok.Init(port); err != nil { +func Start(serverAddr, localAddr string) { + trok, err := NewTrokClient(serverAddr, localAddr) + if err != nil { log.Fatal().Msgf("failed init trok %v", err) } diff --git a/internal/client/tcp.go b/internal/client/tcp.go index 8c509e2..575d449 100644 --- a/internal/client/tcp.go +++ b/internal/client/tcp.go @@ -5,7 +5,6 @@ Copyright © 2024 tux <0xtux@pm.me> package client import ( - "fmt" "net" ) @@ -14,16 +13,13 @@ type TCPClient struct { conn net.Conn } -func (c *TCPClient) Init(port uint16, title string) error { - address := fmt.Sprintf(":%d", port) - conn, err := net.Dial("tcp", address) - if err != nil { - return err - } +func NewTCPClient(addr, title string) (*TCPClient, error) { + conn, err := net.Dial("tcp", addr) - c.title = title - c.conn = conn - return nil + return &TCPClient{ + title: title, + conn: conn, + }, err } func (c *TCPClient) Start(handler func(conn net.Conn)) { @@ -34,6 +30,14 @@ func (c *TCPClient) Stop() error { return c.conn.Close() } +func (s *TCPClient) Addr() string { + return s.conn.RemoteAddr().String() +} + +func (c *TCPClient) Host() string { + return c.conn.RemoteAddr().(*net.TCPAddr).IP.String() +} + func (c *TCPClient) Port() uint16 { return uint16(c.conn.RemoteAddr().(*net.TCPAddr).Port) } diff --git a/internal/client/trok.go b/internal/client/trok.go index da273b0..a7d2609 100644 --- a/internal/client/trok.go +++ b/internal/client/trok.go @@ -14,22 +14,29 @@ import ( ) type Trok struct { - controlClient TCPClient + controlClient *TCPClient + serverAddr string + localAddr string } -func (t *Trok) Init(port uint16) error { - err := t.controlClient.Init(port, "Controller") - return err +func NewTrokClient(serverAddr, localAddr string) (*Trok, error) { + controlClient, err := NewTCPClient(serverAddr, "Controller") + + return &Trok{ + controlClient: controlClient, + serverAddr: serverAddr, + localAddr: localAddr, + }, err } func (t *Trok) Start() { go t.controlClient.Start(t.ControlConnHandler) - log.Info().Msgf("started Trok client on port %d", t.controlClient.Port()) + log.Info().Msgf("started Trok client on %s", t.controlClient.Addr()) } func (t *Trok) Stop() { t.controlClient.Stop() - log.Info().Msgf("stopped Trok client on port %d", t.controlClient.Port()) + log.Info().Msgf("stopped Trok client on %s", t.controlClient.Addr()) } func (t *Trok) ControlConnHandler(conn net.Conn) { @@ -69,22 +76,19 @@ func (t *Trok) hanldeCMDEHLO(m *lib.Message) { func (t *Trok) handleCMDCNCT(m *lib.Message) { log.Info().Msgf("[CMD] %s [ARG] %s", m.CMD, m.ARG) - var upstream TCPClient - var downstream TCPClient - - err := upstream.Init(3000, "UpStream") + upStream, err := NewTCPClient(t.localAddr, "UpStream") if err != nil { log.Error().Msgf("can't connect to upstream socket: %v", err) return } - err = downstream.Init(1421, "DownStream") + downStream, err := NewTCPClient(t.serverAddr, "DownStream") if err != nil { log.Error().Msgf("can't connect to downstream socket: %v", err) return } - downstream.conn.Write([]byte(fmt.Sprintf("ACPT %s\n", m.ARG))) - go io.Copy(upstream.conn, downstream.conn) - io.Copy(downstream.conn, upstream.conn) + downStream.conn.Write([]byte(fmt.Sprintf("ACPT %s\n", m.ARG))) + go io.Copy(upStream.conn, downStream.conn) + io.Copy(downStream.conn, upStream.conn) }