diff --git a/cmd/server.go b/cmd/server.go index 46f859b..dc1b1be 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -15,16 +15,17 @@ var serverCmd = &cobra.Command{ Short: "Initiates the remote proxy server", Long: "Initiates the remote proxy server", Run: func(cmd *cobra.Command, args []string) { - port, err := cmd.Flags().GetUint16("port") + addr, err := cmd.Flags().GetString("addr") if err != nil { panic(err) } - server.Start(port) + server.Start(addr) }, } func init() { rootCmd.AddCommand(serverCmd) - serverCmd.Flags().Uint16P("port", "p", 1421, "Port for the server to listen on") + serverCmd.Flags().StringP("addr", "a", "0.0.0.0:1337", "Addr for the server to listen on") + clientCmd.MarkFlagRequired("addr") } diff --git a/internal/server/init.go b/internal/server/init.go index f72b1bb..e7746e0 100644 --- a/internal/server/init.go +++ b/internal/server/init.go @@ -11,9 +11,9 @@ import ( "github.com/rs/zerolog/log" ) -func Start(port uint16) { +func Start(addr string) { var trok Trok - if err := trok.Init(port); err != nil { + if err := trok.Init(addr); err != nil { log.Fatal().Msgf("failed to init trok %v", err) } diff --git a/internal/server/tcp.go b/internal/server/tcp.go index 22be89d..f7da2ec 100644 --- a/internal/server/tcp.go +++ b/internal/server/tcp.go @@ -5,7 +5,6 @@ Copyright © 2024 tux <0xtux@pm.me> package server import ( - "fmt" "net" ) @@ -14,13 +13,8 @@ type TCPServer struct { listener net.Listener } -func (s *TCPServer) Init(port uint16, title string) error { - address := ":" - if port > 0 { - address = fmt.Sprintf(":%d", port) - } - - ln, err := net.Listen("tcp", address) +func (s *TCPServer) Init(addr, title string) error { + ln, err := net.Listen("tcp", addr) if err != nil { return err } @@ -44,6 +38,14 @@ func (s *TCPServer) Stop() error { return s.listener.Close() } +func (s *TCPServer) Addr() string { + return s.listener.Addr().String() +} + +func (c *TCPServer) Host() string { + return c.listener.Addr().(*net.TCPAddr).IP.String() +} + func (s *TCPServer) Port() uint16 { return uint16(s.listener.Addr().(*net.TCPAddr).Port) } diff --git a/internal/server/trok.go b/internal/server/trok.go index f62c0fb..5a89f33 100644 --- a/internal/server/trok.go +++ b/internal/server/trok.go @@ -27,20 +27,20 @@ type Trok struct { mutex sync.Mutex } -func (t *Trok) Init(port uint16) error { +func (t *Trok) Init(addr string) error { t.publicConns = make(map[string]Conn) - err := t.controlServer.Init(port, "Controller") + err := t.controlServer.Init(addr, "Controller") return err } func (t *Trok) Start() { go t.controlServer.Start(t.ControlConnHandler) - log.Info().Msgf("started Trok server on port %d", t.controlServer.Port()) + log.Info().Msgf("started Trok server on %s", t.controlServer.Addr()) } func (t *Trok) Stop() { t.controlServer.Stop() - log.Info().Msgf("stopped Trok server on port %d", t.controlServer.Port()) + log.Info().Msgf("stopped Trok server on %s", t.controlServer.Addr()) } func (t *Trok) ControlConnHandler(conn net.Conn) { @@ -75,7 +75,7 @@ func (t *Trok) handleCMDHELO(p *lib.ProtocolHandler, m *lib.Message) { log.Info().Msgf("[CMD] %s [ARG] %s", m.CMD, m.ARG) var s TCPServer - err := s.Init(0, "Handler") + err := s.Init(":", "Handler") if err != nil { log.Error().Msgf("error handling HELO cmd: %v", err) return