package modbus

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/asn1"
	"errors"
	"fmt"
	"log"
	"net"
	"strings"
	"sync"
	"time"
)

// Modbus Role PEM OID (see R-21 of the MBAPS spec).
// extractRole() compares this OID against the IDs of the extensions found in
// the client certificate.
var modbusRoleOID asn1.ObjectIdentifier = asn1.ObjectIdentifier{
	1, 3, 6, 1, 4, 1, 50316, 802, 1,
}

// ServerConfiguration is the server configuration object passed to NewServer().
type ServerConfiguration struct {
	// URL defines where to listen at e.g. tcp://[::]:502
	// NewServer() accepts the tcp:// and tcp+tls:// schemes and fails with
	// ErrConfigurationError for anything else.
	URL string
	// Timeout sets the idle session timeout (client connections will
	// be closed if idle for this long).
	// NewServer() substitutes 120 seconds if left at zero.
	Timeout time.Duration
	// MaxClients sets the maximum number of concurrent client connections.
	// NewServer() substitutes 10 if left at zero.
	MaxClients uint
	// TLSServerCert sets the server-side TLS key pair (tcp+tls only).
	// NewServer() refuses a tcp+tls:// URL if this is nil.
	TLSServerCert *tls.Certificate
	// TLSClientCAs sets the list of CA certificates used to authenticate
	// client connections (tcp+tls only). Leaf (i.e. client) certificates can
	// also be used in case of self-signed certs, or if cert pinning is required.
	// NewServer() refuses a tcp+tls:// URL if this is nil.
	TLSClientCAs *x509.CertPool
	// Logger provides a custom sink for log messages.
	// If nil, messages will be written to stdout.
	Logger *log.Logger
}

// CoilsRequest is the request object passed to the coil handler.
type CoilsRequest struct {
	// ClientAddr is the source (client) address, as reported by
	// RemoteAddr().String() on the client connection
	ClientAddr string
	// ClientRole is the client role as encoded in the client certificate
	// (tcp+tls only)
	ClientRole string
	// UnitId is the requested unit id (slave id)
	UnitId uint8
	// Addr is the base coil address requested
	Addr uint16
	// Quantity is the number of consecutive coils covered by this request
	// (first address: Addr, last address: Addr + Quantity - 1)
	Quantity uint16
	// IsWrite is true if the request is a write, false if a read
	IsWrite bool
	// Args is a slice of bool values of the coils to be set, ordered
	// from Addr to Addr + Quantity - 1 (for writes only)
	Args []bool
}

// DiscreteInputsRequest is the request object passed to the discrete input
// handler.
type DiscreteInputsRequest struct { ClientAddr string // the source (client) IP address ClientRole string // the client role as encoded in the client certificate (tcp+tls only) UnitId uint8 // the requested unit id (slave id) Addr uint16 // the base discrete input address requested Quantity uint16 // the number of consecutive discrete inputs covered by this request } // Request object passed to the holding register handler. type HoldingRegistersRequest struct { ClientAddr string // the source (client) IP address ClientRole string // the client role as encoded in the client certificate (tcp+tls only) UnitId uint8 // the requested unit id (slave id) Addr uint16 // the base register address requested Quantity uint16 // the number of consecutive registers covered by this request IsWrite bool // true if the request is a write, false if a read Args []uint16 // a slice of register values to be set, ordered from // Addr to Addr + Quantity - 1 (for writes only) } // Request object passed to the input register handler. type InputRegistersRequest struct { ClientAddr string // the source (client) IP address ClientRole string // the client role as encoded in the client certificate (tcp+tls only) UnitId uint8 // the requested unit id (slave id) Addr uint16 // the base register address requested Quantity uint16 // the number of consecutive registers covered by this request } // The RequestHandler interface should be implemented by the handler // object passed to NewServer (see reqHandler in NewServer()). // After decoding and validating an incoming request, the server will // invoke the appropriate handler function, depending on the function code // of the request. type RequestHandler interface { // HandleCoils handles the read coils (0x01), write single coil (0x05) // and write multiple coils (0x0f) function codes. // A CoilsRequest object is passed to the handler (see above). 
// // Expected return values: // - res: a slice of bools containing the coil values to be sent to back // to the client (only sent for reads), // - err: either nil if no error occurred, a modbus error (see // mapErrorToExceptionCode() in modbus.go for a complete list), // or any other error. // If nil, a positive modbus response is sent back to the client // along with the returned data. // If non-nil, a negative modbus response is sent back, with the // exception code set depending on the error // (again, see mapErrorToExceptionCode()). HandleCoils (req *CoilsRequest) (res []bool, err error) // HandleDiscreteInputs handles the read discrete inputs (0x02) function code. // A DiscreteInputsRequest oibject is passed to the handler (see above). // // Expected return values: // - res: a slice of bools containing the discrete input values to be // sent back to the client, // - err: either nil if no error occurred, a modbus error (see // mapErrorToExceptionCode() in modbus.go for a complete list), // or any other error. HandleDiscreteInputs (req *DiscreteInputsRequest) (res []bool, err error) // HandleHoldingRegisters handles the read holding registers (0x03), // write single register (0x06) and write multiple registers (0x10). // A HoldingRegistersRequest object is passed to the handler (see above). // // Expected return values: // - res: a slice of uint16 containing the register values to be sent // to back to the client (only sent for reads), // - err: either nil if no error occurred, a modbus error (see // mapErrorToExceptionCode() in modbus.go for a complete list), // or any other error. HandleHoldingRegisters (req *HoldingRegistersRequest) (res []uint16, err error) // HandleInputRegisters handles the read input registers (0x04) function code. // An InputRegistersRequest object is passed to the handler (see above). 
// // Expected return values: // - res: a slice of uint16 containing the register values to be sent // back to the client, // - err: either nil if no error occurred, a modbus error (see // mapErrorToExceptionCode() in modbus.go for a complete list), // or any other error. HandleInputRegisters (req *InputRegistersRequest) (res []uint16, err error) } // Modbus server object. type ModbusServer struct { conf ServerConfiguration logger *logger lock sync.Mutex started bool handler RequestHandler tcpListener net.Listener tcpClients []net.Conn transportType transportType } // Returns a new modbus server. // reqHandler should be a user-provided handler object satisfying the RequestHandler // interface. func NewServer(conf *ServerConfiguration, reqHandler RequestHandler) ( ms *ModbusServer, err error) { var serverType string var splitURL []string ms = &ModbusServer{ conf: *conf, handler: reqHandler, } splitURL = strings.SplitN(ms.conf.URL, "://", 2) if len(splitURL) == 2 { serverType = splitURL[0] ms.conf.URL = splitURL[1] } ms.logger = newLogger( fmt.Sprintf("modbus-server(%s)", ms.conf.URL), ms.conf.Logger) if ms.conf.URL == "" { ms.logger.Errorf("missing host part in URL '%s'", conf.URL) err = ErrConfigurationError return } switch serverType { case "tcp": if ms.conf.Timeout == 0 { ms.conf.Timeout = 120 * time.Second } if ms.conf.MaxClients == 0 { ms.conf.MaxClients = 10 } ms.transportType = modbusTCP case "tcp+tls": if ms.conf.Timeout == 0 { ms.conf.Timeout = 120 * time.Second } if ms.conf.MaxClients == 0 { ms.conf.MaxClients = 10 } // expect a server-side certificate if ms.conf.TLSServerCert == nil { ms.logger.Errorf("missing server certificate") err = ErrConfigurationError return } // expect a CertPool object containing at least 1 CA or // leaf certificate to validate client-side certificates if ms.conf.TLSClientCAs == nil { ms.logger.Errorf("missing CA/client certificates") err = ErrConfigurationError return } ms.transportType = modbusTCPOverTLS default: err = 
ErrConfigurationError return } return } // Starts accepting client connections. func (ms *ModbusServer) Start() (err error) { ms.lock.Lock() defer ms.lock.Unlock() if ms.started { return } switch ms.transportType { case modbusTCP, modbusTCPOverTLS: // bind to a TCP socket ms.tcpListener, err = net.Listen("tcp", ms.conf.URL) if err != nil { return } // accept client connections in a goroutine go ms.acceptTCPClients() default: err = ErrConfigurationError return } ms.started = true return } // Stops accepting new client connections and closes any active session. func (ms *ModbusServer) Stop() (err error) { ms.lock.Lock() defer ms.lock.Unlock() if !ms.started { return } ms.started = false if ms.transportType == modbusTCP || ms.transportType == modbusTCPOverTLS { // close the server socket if we're listening over TCP err = ms.tcpListener.Close() // close all active TCP clients for _, sock := range ms.tcpClients{ sock.Close() } } return } // Accepts new client connections if the configured connection limit allows it. // Each connection is served from a dedicated goroutine to allow for concurrent // connections. func (ms *ModbusServer) acceptTCPClients() { var sock net.Conn var err error var accepted bool for { sock, err = ms.tcpListener.Accept() if err != nil { // if the server socket has just been closed, return here as // this goroutine isn't going to see any new client connection if errors.Is(err, net.ErrClosed) { return } ms.logger.Warningf("failed to accept client connection: %v", err) continue } ms.lock.Lock() // apply a connection limit if ms.started && uint(len(ms.tcpClients)) < ms.conf.MaxClients { accepted = true // add the new client connection to the pool ms.tcpClients = append(ms.tcpClients, sock) } else { accepted = false } ms.lock.Unlock() if accepted { // spin a client handler goroutine to serve the new client go ms.handleTCPClient(sock) } else { ms.logger.Warningf("max. 
number of concurrent connections " + "reached, rejecting %v", sock.RemoteAddr()) // discard the connection sock.Close() } } // never reached return } // Handles a TCP client connection. // Once handleTransport() returns (i.e. the connection has either closed, timed // out, or an unrecoverable error happened), the TCP socket is closed and removed // from the list of active client connections. func (ms *ModbusServer) handleTCPClient(sock net.Conn) { var err error var clientRole string var tlsSock net.Conn switch ms.transportType { case modbusTCP: // serve modbus requests over the raw TCP connection ms.handleTransport( newTCPTransport(sock, ms.conf.Timeout, ms.conf.Logger), sock.RemoteAddr().String(), "") case modbusTCPOverTLS: // start TLS negotiation over the raw TCP connection tlsSock, clientRole, err = ms.startTLS(sock) if err != nil { ms.logger.Warningf("TLS handshake with %s failed: %v", sock.RemoteAddr().String(), err) } else { // serve modbus requests over the TLS tunnel ms.handleTransport( newTCPTransport(tlsSock, ms.conf.Timeout, ms.conf.Logger), sock.RemoteAddr().String(), clientRole) } default: ms.logger.Errorf("unimplemented transport type %v", ms.transportType) } // once done, remove our connection from the list of active client conns ms.lock.Lock() for i := range ms.tcpClients { if ms.tcpClients[i] == sock { ms.tcpClients[i] = ms.tcpClients[len(ms.tcpClients)-1] ms.tcpClients = ms.tcpClients[:len(ms.tcpClients)-1] break } } ms.lock.Unlock() // close the connection sock.Close() return } // For each request read from the transport, performs decoding and validation, // calls the user-provided handler, then encodes and writes the response // to the transport. 
//
// t is the transport requests are read from and responses written to,
// clientAddr and clientRole are copied verbatim into every request object
// handed to the user-provided handler.
// The function returns when reading a request fails or after a protocol error
// (in which case the transport is closed first).
func (ms *ModbusServer) handleTransport(t transport, clientAddr string, clientRole string) {
	var req *pdu
	var res *pdu
	var err error
	var addr uint16
	var quantity uint16

	for {
		req, err = t.ReadRequest()
		if err != nil {
			return
		}

		// note: within the switch below, `break` leaves the switch (not the
		// for loop), so that any error set by a case is processed by the
		// common error handling code which follows the switch.
		switch req.functionCode {
		case fcReadCoils, fcReadDiscreteInputs:
			var coils []bool
			var resCount int

			// the request payload is exactly address (2 bytes) + quantity (2 bytes)
			if len(req.payload) != 4 {
				err = ErrProtocolError
				break
			}

			// decode address and quantity fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// ensure the reply never exceeds the maximum PDU length and we
			// never read past 0xffff
			if quantity > 2000 || quantity == 0 {
				err = ErrProtocolError
				break
			}
			// computed on 32 bits so that the sum cannot wrap around
			if uint32(addr)+uint32(quantity)-1 > 0xffff {
				err = ErrIllegalDataAddress
				break
			}

			// invoke the appropriate handler
			if req.functionCode == fcReadCoils {
				coils, err = ms.handler.HandleCoils(&CoilsRequest{
					ClientAddr: clientAddr,
					ClientRole: clientRole,
					UnitId:     req.unitId,
					Addr:       addr,
					Quantity:   quantity,
					IsWrite:    false,
					Args:       nil,
				})
			} else {
				coils, err = ms.handler.HandleDiscreteInputs(
					&DiscreteInputsRequest{
						ClientAddr: clientAddr,
						ClientRole: clientRole,
						UnitId:     req.unitId,
						Addr:       addr,
						Quantity:   quantity,
					})
			}
			resCount = len(coils)

			// make sure the handler returned the expected number of items
			if err == nil && resCount != int(quantity) {
				ms.logger.Errorf("handler returned %v bools, "+
					"expected %v", resCount, quantity)
				err = ErrServerDeviceFailure
				break
			}

			if err != nil {
				break
			}

			// assemble a response PDU
			// (payload[0] is a placeholder for the byte count, set below)
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
				payload:      []byte{0},
			}

			// byte count (1 byte for 8 coils)
			res.payload[0] = uint8(resCount / 8)
			if resCount%8 != 0 {
				res.payload[0]++
			}

			// coil values
			res.payload = append(res.payload, encodeBools(coils)...)

		case fcWriteSingleCoil:
			// the request payload is exactly address (2 bytes) + value (2 bytes)
			if len(req.payload) != 4 {
				err = ErrProtocolError
				break
			}

			// decode the address field
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])

			// validate the value field (should be either 0xff00 or 0x0000)
			if (req.payload[2] != 0xff && req.payload[2] != 0x00) ||
				req.payload[3] != 0x00 {
				err = ErrProtocolError
				break
			}

			// invoke the coil handler
			_, err = ms.handler.HandleCoils(&CoilsRequest{
				ClientAddr: clientAddr,
				ClientRole: clientRole,
				UnitId:     req.unitId,
				Addr:       addr,
				Quantity:   1,    // request for a single coil
				IsWrite:    true, // this is a write request
				Args:       []bool{(req.payload[2] == 0xff)},
			})

			if err != nil {
				break
			}

			// assemble a response PDU
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
			}

			// echo the address and value in the response
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, addr)...)
			res.payload = append(res.payload,
				req.payload[2], req.payload[3])

		case fcWriteMultipleCoils:
			var expectedLen int

			// address (2 bytes) + quantity (2 bytes) + byte count (1 byte) +
			// at least 1 byte of coil values
			if len(req.payload) < 6 {
				err = ErrProtocolError
				break
			}

			// decode address and quantity fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// ensure the reply never exceeds the maximum PDU length and we
			// never read past 0xffff
			// (0x7b0 is 1968 coils)
			if quantity > 0x7b0 || quantity == 0 {
				err = ErrProtocolError
				break
			}
			if uint32(addr)+uint32(quantity)-1 > 0xffff {
				err = ErrIllegalDataAddress
				break
			}

			// validate the byte count field (1 byte for 8 coils)
			expectedLen = int(quantity) / 8
			if quantity%8 != 0 {
				expectedLen++
			}

			if req.payload[4] != uint8(expectedLen) {
				err = ErrProtocolError
				break
			}

			// make sure we have enough bytes
			// (the first 5 bytes are the address, quantity and byte count fields)
			if len(req.payload)-5 != expectedLen {
				err = ErrProtocolError
				break
			}

			// invoke the coil handler
			_, err = ms.handler.HandleCoils(&CoilsRequest{
				ClientAddr: clientAddr,
				ClientRole: clientRole,
				UnitId:     req.unitId,
				Addr:       addr,
				Quantity:   quantity,
				IsWrite:    true, // this is a write request
				Args:       decodeBools(quantity, req.payload[5:]),
			})

			if err != nil {
				break
			}

			// assemble a response PDU
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
			}

			// echo the address and quantity in the response
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, addr)...)
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, quantity)...)

		case fcReadHoldingRegisters, fcReadInputRegisters:
			var regs []uint16
			var resCount int

			// the request payload is exactly address (2 bytes) + quantity (2 bytes)
			if len(req.payload) != 4 {
				err = ErrProtocolError
				break
			}

			// decode address and quantity fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// ensure the reply never exceeds the maximum PDU length and we
			// never read past 0xffff
			// (0x007d is 125 registers, i.e. 250 bytes of register data)
			if quantity > 0x007d || quantity == 0 {
				err = ErrProtocolError
				break
			}
			if uint32(addr)+uint32(quantity)-1 > 0xffff {
				err = ErrIllegalDataAddress
				break
			}

			// invoke the appropriate handler
			if req.functionCode == fcReadHoldingRegisters {
				regs, err = ms.handler.HandleHoldingRegisters(
					&HoldingRegistersRequest{
						ClientAddr: clientAddr,
						ClientRole: clientRole,
						UnitId:     req.unitId,
						Addr:       addr,
						Quantity:   quantity,
						IsWrite:    false,
						Args:       nil,
					})
			} else {
				regs, err = ms.handler.HandleInputRegisters(
					&InputRegistersRequest{
						ClientAddr: clientAddr,
						ClientRole: clientRole,
						UnitId:     req.unitId,
						Addr:       addr,
						Quantity:   quantity,
					})
			}
			resCount = len(regs)

			// make sure the handler returned the expected number of items
			if err == nil && resCount != int(quantity) {
				ms.logger.Errorf("handler returned %v 16-bit values, "+
					"expected %v", resCount, quantity)
				err = ErrServerDeviceFailure
				break
			}

			if err != nil {
				break
			}

			// assemble a response PDU
			// (payload[0] is a placeholder for the byte count, set below)
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
				payload:      []byte{0},
			}

			// byte count (2 bytes per register)
			res.payload[0] = uint8(resCount * 2)

			// register values
			res.payload = append(res.payload,
				uint16sToBytes(BIG_ENDIAN, regs)...)

		case fcWriteSingleRegister:
			var value uint16

			// the request payload is exactly address (2 bytes) + value (2 bytes)
			if len(req.payload) != 4 {
				err = ErrProtocolError
				break
			}

			// decode address and value fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			value = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// invoke the handler
			_, err = ms.handler.HandleHoldingRegisters(
				&HoldingRegistersRequest{
					ClientAddr: clientAddr,
					ClientRole: clientRole,
					UnitId:     req.unitId,
					Addr:       addr,
					Quantity:   1,    // request for a single register
					IsWrite:    true, // request is a write
					Args:       []uint16{value},
				})

			if err != nil {
				break
			}

			// assemble a response PDU
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
			}

			// echo the address and value in the response
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, addr)...)
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, value)...)

		case fcWriteMultipleRegisters:
			var expectedLen int

			// address (2 bytes) + quantity (2 bytes) + byte count (1 byte) +
			// at least 1 byte of register data
			if len(req.payload) < 6 {
				err = ErrProtocolError
				break
			}

			// decode address and quantity fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// ensure the reply never exceeds the maximum PDU length and we
			// never read past 0xffff
			// (0x007b is 123 registers, i.e. 246 bytes of register data)
			if quantity > 0x007b || quantity == 0 {
				err = ErrProtocolError
				break
			}
			if uint32(addr)+uint32(quantity)-1 > 0xffff {
				err = ErrIllegalDataAddress
				break
			}

			// validate the byte count field (2 bytes per register)
			expectedLen = int(quantity) * 2

			if req.payload[4] != uint8(expectedLen) {
				err = ErrProtocolError
				break
			}

			// make sure we have enough bytes
			// (the first 5 bytes are the address, quantity and byte count fields)
			if len(req.payload)-5 != expectedLen {
				err = ErrProtocolError
				break
			}

			// invoke the holding register handler
			_, err = ms.handler.HandleHoldingRegisters(
				&HoldingRegistersRequest{
					ClientAddr: clientAddr,
					ClientRole: clientRole,
					UnitId:     req.unitId,
					Addr:       addr,
					Quantity:   quantity,
					IsWrite:    true, // this is a write request
					Args:       bytesToUint16s(BIG_ENDIAN, req.payload[5:]),
				})

			if err != nil {
				break
			}

			// assemble a response PDU
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
			}

			// echo the address and quantity in the response
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, addr)...)
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, quantity)...)

		default:
			res = &pdu{
				// reply with the request target unit ID
				unitId: req.unitId,
				// set the error bit
				functionCode: (0x80 | req.functionCode),
				// set the exception code to illegal function to indicate that
				// the server does not know how to handle this function code.
				payload: []byte{exIllegalFunction},
			}
		}

		// if there was no error processing the request but the response is nil
		// (which should never happen), emit a server failure exception code
		// and log an error
		if err == nil && res == nil {
			err = ErrServerDeviceFailure
			ms.logger.Errorf("internal server error (req: %v, res: %v, err: %v)",
				req, res, err)
		}

		// map go errors to modbus errors, unless the error is a protocol error,
		// in which case close the transport and return.
		// (this applies to any ErrProtocolError, including one returned by the
		// user-provided handler, as the origin of the error is not tracked)
		if err != nil {
			if err == ErrProtocolError {
				ms.logger.Warningf(
					"protocol error, closing link (client address: '%s')",
					clientAddr)
				t.Close()
				return
			} else {
				// negative response: error bit set on the function code,
				// exception code as the sole payload byte
				res = &pdu{
					unitId:       req.unitId,
					functionCode: (0x80 | req.functionCode),
					payload:      []byte{mapErrorToExceptionCode(err)},
				}
			}
		}

		// write the response to the transport
		// (a failed write is only logged: the loop carries on and reads the
		// next request)
		err = t.WriteResponse(res)
		if err != nil {
			ms.logger.Warningf("failed to write response: %v", err)
		}

		// avoid holding on to stale data
		req = nil
		res = nil
	}

	// never reached
	return
}

// startTLS performs a full TLS handshake (with client authentication) on tcpSock
// and returns a 'wrapped' clear-text socket suitable for use by the TCP transport.
func (ms *ModbusServer) startTLS(tcpSock net.Conn) ( tlsSock *tls.Conn, clientRole string, err error) { var connState tls.ConnectionState // set a 30s timeout for the TLS handshake to complete err = tcpSock.SetDeadline(time.Now().Add(30 * time.Second)) if err != nil { return } // start TLS negotiation over the raw TCP connection tlsSock = tls.Server(tcpSock, &tls.Config{ Certificates: []tls.Certificate{ *ms.conf.TLSServerCert, }, ClientCAs: ms.conf.TLSClientCAs, // require a valid (verified) certificate from the client // (see R-06, R-08 and R-10 of the MBAPS spec) ClientAuth: tls.RequireAndVerifyClientCert, // mandate TLSv1.2 or higher (see R-01 of the MBAPS spec) MinVersion: tls.VersionTLS12, }) // complete the full TLS handshake (with client cert validation) err = tlsSock.Handshake() if err != nil { return } // look for and extract the client's role, if any connState = tlsSock.ConnectionState() if len(connState.PeerCertificates) == 0 { err = errors.New("no client certificate received") return } // From the tls.ConnectionState doc: // "The first element is the leaf certificate that the connection is // verified against." clientRole = ms.extractRole(connState.PeerCertificates[0]) return } // extractRole looks for Modbus Role extensions in a certificate and returns the // role as a string. // If no role extension is found, a nil string is returned (R-23). // If multiple or invalid role extensions are found, a nil string is returned (R-65, R-22). 
func (ms *ModbusServer) extractRole(cert *x509.Certificate) (role string) { var err error var found bool var badCert bool // walk through all extensions looking for Modbus Role OIDs for _, ext := range cert.Extensions { if ext.Id.Equal(modbusRoleOID) { // there must be only one role extension per cert (R-65) if found { ms.logger.Warning("client certificate contains more than one role OIDs") badCert = true break } found = true // the role extension must use UTF8String encoding (R-22) // (the ASN1 tag for UTF8String is 0x0c) if len(ext.Value) < 2 || ext.Value[0] != 0x0c { badCert = true break } // extract the ASN1 string _, err = asn1.Unmarshal(ext.Value, &role) if err != nil { ms.logger.Warningf("failed to decode Modbus Role extension: %v", err) badCert = true break } } } // blank the role if we found more than one Role extension if badCert { role = "" } return }