Files
modbus/server.go
2025-11-07 13:53:18 +08:00

902 lines
25 KiB
Go

package modbus
import (
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"errors"
"fmt"
"log"
"net"
"strings"
"sync"
"time"
)
// modbusRoleOID is the X.509 extension OID carrying the Modbus Role of a
// client certificate (see R-21 of the MBAPS spec); extractRole searches
// client certificates for it.
var modbusRoleOID = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 50316, 802, 1}
// ServerConfiguration holds the settings NewServer uses to build a
// ModbusServer.
type ServerConfiguration struct {
	// URL is the listen address prefixed with the transport scheme,
	// "tcp://" or "tcp+tls://", e.g. tcp://[::]:502.
	URL string

	// Timeout is the idle session timeout: client connections idle for
	// this long are closed. NewServer substitutes 120s when it is zero.
	Timeout time.Duration

	// MaxClients caps the number of concurrent client connections.
	// NewServer substitutes 10 when it is zero.
	MaxClients uint

	// TLSServerCert is the server-side TLS key pair
	// (tcp+tls only, where it is mandatory).
	TLSServerCert *tls.Certificate

	// TLSClientCAs holds the CA certificates used to authenticate client
	// connections (tcp+tls only, where it is mandatory). Leaf (i.e. client)
	// certificates may be added as well, e.g. for self-signed certs or when
	// certificate pinning is required.
	TLSClientCAs *x509.CertPool

	// Logger is an optional custom sink for log messages.
	// When nil, messages are written to stdout.
	Logger *log.Logger
}
// CoilsRequest describes a coil read or write; it is handed to
// RequestHandler.HandleCoils.
type CoilsRequest struct {
	ClientAddr string // network address of the client (from the connection's RemoteAddr)
	ClientRole string // role encoded in the client certificate (tcp+tls only)
	UnitId     uint8  // target unit id (slave id)
	Addr       uint16 // first coil address covered by the request

	// Quantity is the number of consecutive coils addressed, i.e. the
	// range Addr .. Addr + Quantity - 1.
	Quantity uint16

	IsWrite bool // true for a write, false for a read

	// Args holds the coil values to set, ordered from Addr to
	// Addr + Quantity - 1 (writes only; nil for reads).
	Args []bool
}
// DiscreteInputsRequest describes a discrete input read; it is handed to
// RequestHandler.HandleDiscreteInputs.
type DiscreteInputsRequest struct {
	ClientAddr string // network address of the client (from the connection's RemoteAddr)
	ClientRole string // role encoded in the client certificate (tcp+tls only)
	UnitId     uint8  // target unit id (slave id)
	Addr       uint16 // first discrete input address covered by the request
	Quantity   uint16 // number of consecutive discrete inputs, starting at Addr
}
// HoldingRegistersRequest describes a holding register read or write; it is
// handed to RequestHandler.HandleHoldingRegisters.
type HoldingRegistersRequest struct {
	ClientAddr string // network address of the client (from the connection's RemoteAddr)
	ClientRole string // role encoded in the client certificate (tcp+tls only)
	UnitId     uint8  // target unit id (slave id)
	Addr       uint16 // first register address covered by the request
	Quantity   uint16 // number of consecutive registers, starting at Addr
	IsWrite    bool   // true for a write, false for a read

	// Args holds the register values to set, ordered from Addr to
	// Addr + Quantity - 1 (writes only; nil for reads).
	Args []uint16
}
// InputRegistersRequest describes an input register read; it is handed to
// RequestHandler.HandleInputRegisters.
type InputRegistersRequest struct {
	ClientAddr string // network address of the client (from the connection's RemoteAddr)
	ClientRole string // role encoded in the client certificate (tcp+tls only)
	UnitId     uint8  // target unit id (slave id)
	Addr       uint16 // first register address covered by the request
	Quantity   uint16 // number of consecutive registers, starting at Addr
}
// RequestHandler is the interface implemented by the user-provided handler
// object passed to NewServer (see reqHandler in NewServer()).
// After decoding and validating an incoming request, the server invokes the
// handler method matching the request's function code.
//
// Returning ErrProtocolError from any method makes the server close the
// client connection instead of sending an exception response.
type RequestHandler interface {
	// HandleCoils handles the read coils (0x01), write single coil (0x05)
	// and write multiple coils (0x0f) function codes.
	// A CoilsRequest object is passed to the handler (see above).
	//
	// Expected return values:
	//   - res: a slice of bools containing the coil values to be sent back
	//     to the client. It is only used for reads, where it must hold
	//     exactly req.Quantity values; any other length is treated as
	//     ErrServerDeviceFailure. It is ignored for writes.
	//   - err: either nil if no error occurred, a modbus error (see
	//     mapErrorToExceptionCode() in modbus.go for a complete list),
	//     or any other error.
	//     If nil, a positive modbus response is sent back to the client
	//     along with the returned data.
	//     If non-nil, a negative modbus response is sent back, with the
	//     exception code set depending on the error
	//     (again, see mapErrorToExceptionCode()).
	HandleCoils(req *CoilsRequest) (res []bool, err error)

	// HandleDiscreteInputs handles the read discrete inputs (0x02) function
	// code. A DiscreteInputsRequest object is passed to the handler (see
	// above).
	//
	// Expected return values:
	//   - res: a slice of bools containing the discrete input values to be
	//     sent back to the client. It must hold exactly req.Quantity values;
	//     any other length is treated as ErrServerDeviceFailure.
	//   - err: either nil if no error occurred, a modbus error (see
	//     mapErrorToExceptionCode() in modbus.go for a complete list),
	//     or any other error.
	HandleDiscreteInputs(req *DiscreteInputsRequest) (res []bool, err error)

	// HandleHoldingRegisters handles the read holding registers (0x03),
	// write single register (0x06) and write multiple registers (0x10)
	// function codes.
	// A HoldingRegistersRequest object is passed to the handler (see above).
	//
	// Expected return values:
	//   - res: a slice of uint16 containing the register values to be sent
	//     back to the client. It is only used for reads, where it must hold
	//     exactly req.Quantity values; any other length is treated as
	//     ErrServerDeviceFailure. It is ignored for writes.
	//   - err: either nil if no error occurred, a modbus error (see
	//     mapErrorToExceptionCode() in modbus.go for a complete list),
	//     or any other error.
	HandleHoldingRegisters(req *HoldingRegistersRequest) (res []uint16, err error)

	// HandleInputRegisters handles the read input registers (0x04) function
	// code. An InputRegistersRequest object is passed to the handler (see
	// above).
	//
	// Expected return values:
	//   - res: a slice of uint16 containing the register values to be sent
	//     back to the client. It must hold exactly req.Quantity values;
	//     any other length is treated as ErrServerDeviceFailure.
	//   - err: either nil if no error occurred, a modbus error (see
	//     mapErrorToExceptionCode() in modbus.go for a complete list),
	//     or any other error.
	HandleInputRegisters(req *InputRegistersRequest) (res []uint16, err error)
}
// ModbusServer accepts Modbus client connections and dispatches decoded
// requests to a RequestHandler. Instances are created with NewServer.
type ModbusServer struct {
	conf    ServerConfiguration // copy of the caller's configuration, with defaults applied
	logger  *logger
	lock    sync.Mutex // serializes Start/Stop and access to tcpClients
	started bool       // true between a successful Start and the next Stop
	handler RequestHandler

	// tcpListener is the listening socket (tcp and tcp+tls transports).
	tcpListener net.Listener
	// tcpClients lists the currently active client connections.
	tcpClients    []net.Conn
	transportType transportType
}
// NewServer returns a new modbus server configured from conf.
// reqHandler should be a user-provided handler object satisfying the
// RequestHandler interface.
//
// conf.URL must carry a supported scheme ("tcp://" or "tcp+tls://") followed
// by a non-empty host part. A zero Timeout defaults to 120s and a zero
// MaxClients to 10. The tcp+tls scheme additionally requires TLSServerCert
// and TLSClientCAs. ErrConfigurationError is returned when conf is nil or
// any of these requirements is not met.
// conf is copied (shallowly), so later changes to it do not affect the server.
func NewServer(conf *ServerConfiguration, reqHandler RequestHandler) (
	ms *ModbusServer, err error) {
	// reject a nil configuration instead of panicking on the dereference
	if conf == nil {
		err = ErrConfigurationError
		return
	}

	ms = &ModbusServer{
		conf:    *conf,
		handler: reqHandler,
	}

	// split "<scheme>://<host part>"; without a separator the scheme stays
	// empty and is rejected by the switch below
	var serverType string
	if splitURL := strings.SplitN(ms.conf.URL, "://", 2); len(splitURL) == 2 {
		serverType = splitURL[0]
		ms.conf.URL = splitURL[1]
	}

	ms.logger = newLogger(
		fmt.Sprintf("modbus-server(%s)", ms.conf.URL), ms.conf.Logger)

	if ms.conf.URL == "" {
		ms.logger.Errorf("missing host part in URL '%s'", conf.URL)
		err = ErrConfigurationError
		return
	}

	switch serverType {
	case "tcp":
		ms.transportType = modbusTCP

	case "tcp+tls":
		// expect a server-side certificate
		if ms.conf.TLSServerCert == nil {
			ms.logger.Errorf("missing server certificate")
			err = ErrConfigurationError
			return
		}
		// expect a CertPool object containing at least 1 CA or
		// leaf certificate to validate client-side certificates
		if ms.conf.TLSClientCAs == nil {
			ms.logger.Errorf("missing CA/client certificates")
			err = ErrConfigurationError
			return
		}
		ms.transportType = modbusTCPOverTLS

	default:
		// previously rejected silently, leaving the caller without a hint
		ms.logger.Errorf("unsupported or missing server type '%s' in URL '%s'",
			serverType, conf.URL)
		err = ErrConfigurationError
		return
	}

	// defaults shared by both TCP-based transports
	if ms.conf.Timeout == 0 {
		ms.conf.Timeout = 120 * time.Second
	}
	if ms.conf.MaxClients == 0 {
		ms.conf.MaxClients = 10
	}

	return
}
// Start binds the listening socket and begins accepting client connections
// in a background goroutine. Calling Start on a running server is a no-op.
// It returns the error from net.Listen, or ErrConfigurationError when the
// transport type is not a TCP-based one.
func (ms *ModbusServer) Start() (err error) {
	ms.lock.Lock()
	defer ms.lock.Unlock()

	if ms.started {
		return nil
	}

	if ms.transportType != modbusTCP && ms.transportType != modbusTCPOverTLS {
		return ErrConfigurationError
	}

	// both transports listen on a plain TCP socket; for tcp+tls the TLS
	// handshake happens later, per connection
	ms.tcpListener, err = net.Listen("tcp", ms.conf.URL)
	if err != nil {
		return err
	}

	go ms.acceptTCPClients()

	ms.started = true
	return nil
}
// Stop closes the listening socket, so no new clients are accepted, and
// closes every active client connection. Calling Stop on a server that is
// not running is a no-op. The error from closing the listener is returned;
// errors from closing client connections are ignored (best effort).
func (ms *ModbusServer) Stop() (err error) {
	ms.lock.Lock()
	defer ms.lock.Unlock()

	if !ms.started {
		return nil
	}
	ms.started = false

	isTCP := ms.transportType == modbusTCP ||
		ms.transportType == modbusTCPOverTLS
	if !isTCP {
		return nil
	}

	err = ms.tcpListener.Close()

	// closing the client sockets should make their handler goroutines'
	// reads fail; each goroutine then removes its own entry from tcpClients
	for i := range ms.tcpClients {
		ms.tcpClients[i].Close()
	}

	return err
}
// acceptTCPClients accepts new client connections on the listener that was
// current when it started, as long as the server is running and the
// configured connection limit allows it. Each accepted connection is served
// from a dedicated goroutine (handleTCPClient) to allow for concurrent
// connections. It returns once the listener has been closed (see Stop).
func (ms *ModbusServer) acceptTCPClients() {
	// snapshot the listener under the lock: Start assigns ms.tcpListener
	// while holding it, so reading the field unguarded on every iteration
	// would race with a Stop/Start restart
	ms.lock.Lock()
	listener := ms.tcpListener
	ms.lock.Unlock()

	// delay applied after consecutive Accept failures
	var backoff time.Duration

	for {
		sock, err := listener.Accept()
		if err != nil {
			// if the server socket has just been closed, return here as
			// this goroutine isn't going to see any new client connection
			if errors.Is(err, net.ErrClosed) {
				return
			}
			ms.logger.Warningf("failed to accept client connection: %v", err)

			// persistent failures (e.g. file descriptor exhaustion) would
			// otherwise spin this loop and flood the log: back off
			// exponentially, capped at one second
			if backoff == 0 {
				backoff = 5 * time.Millisecond
			} else {
				backoff *= 2
			}
			if backoff > time.Second {
				backoff = time.Second
			}
			time.Sleep(backoff)
			continue
		}
		backoff = 0

		// apply the connection limit and register the client atomically
		ms.lock.Lock()
		running := ms.started
		accepted := running && uint(len(ms.tcpClients)) < ms.conf.MaxClients
		if accepted {
			ms.tcpClients = append(ms.tcpClients, sock)
		}
		ms.lock.Unlock()

		if accepted {
			// spin a client handler goroutine to serve the new client
			go ms.handleTCPClient(sock)
			continue
		}

		if running {
			ms.logger.Warningf("max. number of concurrent connections "+
				"reached, rejecting %v", sock.RemoteAddr())
		} else {
			ms.logger.Warningf("server is stopping, rejecting %v",
				sock.RemoteAddr())
		}
		// discard the connection
		sock.Close()
	}
}
// handleTCPClient serves a single client connection. For tcp+tls, a TLS
// handshake with client certificate validation runs first; if it fails, no
// request is served. Once handleTransport returns (the connection closed,
// timed out, or hit an unrecoverable error), the socket is removed from the
// list of active client connections and closed.
func (ms *ModbusServer) handleTCPClient(sock net.Conn) {
	remote := sock.RemoteAddr().String()

	switch ms.transportType {
	case modbusTCP:
		// plain Modbus/TCP carries no client role
		ms.handleTransport(
			newTCPTransport(sock, ms.conf.Timeout, ms.conf.Logger),
			remote, "")

	case modbusTCPOverTLS:
		tlsSock, clientRole, err := ms.startTLS(sock)
		if err != nil {
			ms.logger.Warningf("TLS handshake with %s failed: %v",
				remote, err)
			break
		}
		// serve modbus requests over the TLS tunnel
		ms.handleTransport(
			newTCPTransport(tlsSock, ms.conf.Timeout, ms.conf.Logger),
			remote, clientRole)

	default:
		ms.logger.Errorf("unimplemented transport type %v", ms.transportType)
	}

	// unregister this connection (swap with the last entry, then truncate)
	ms.lock.Lock()
	for i, conn := range ms.tcpClients {
		if conn != sock {
			continue
		}
		last := len(ms.tcpClients) - 1
		ms.tcpClients[i] = ms.tcpClients[last]
		ms.tcpClients = ms.tcpClients[:last]
		break
	}
	ms.lock.Unlock()

	sock.Close()
}
// handleTransport serves requests read from t until reading fails (e.g. the
// peer disconnected or the transport timed out), a protocol error occurs,
// or a response cannot be written.
//
// For each request it decodes and validates the PDU, calls the matching
// user-provided handler method, then encodes and writes the response to the
// transport. Handler errors are turned into exception responses via
// mapErrorToExceptionCode(). ErrProtocolError, whether raised by validation
// here or returned by a handler, closes the transport instead of being
// answered. clientAddr and clientRole are passed through to the handler.
func (ms *ModbusServer) handleTransport(t transport, clientAddr string, clientRole string) {
	for {
		req, err := t.ReadRequest()
		if err != nil {
			return
		}

		// per-iteration state: a response from a previous request can never
		// leak into this one
		var res *pdu
		var addr, quantity uint16

		switch req.functionCode {
		case fcReadCoils, fcReadDiscreteInputs:
			var coils []bool

			if len(req.payload) != 4 {
				err = ErrProtocolError
				break
			}
			// decode address and quantity fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// ensure the reply never exceeds the maximum PDU length and we
			// never read past 0xffff
			if quantity > 2000 || quantity == 0 {
				err = ErrProtocolError
				break
			}
			if uint32(addr)+uint32(quantity)-1 > 0xffff {
				err = ErrIllegalDataAddress
				break
			}

			// invoke the appropriate handler
			if req.functionCode == fcReadCoils {
				coils, err = ms.handler.HandleCoils(&CoilsRequest{
					ClientAddr: clientAddr,
					ClientRole: clientRole,
					UnitId:     req.unitId,
					Addr:       addr,
					Quantity:   quantity,
					IsWrite:    false,
					Args:       nil,
				})
			} else {
				coils, err = ms.handler.HandleDiscreteInputs(
					&DiscreteInputsRequest{
						ClientAddr: clientAddr,
						ClientRole: clientRole,
						UnitId:     req.unitId,
						Addr:       addr,
						Quantity:   quantity,
					})
			}
			if err != nil {
				break
			}
			// make sure the handler returned the expected number of items
			if len(coils) != int(quantity) {
				ms.logger.Errorf("handler returned %v bools, "+
					"expected %v", len(coils), quantity)
				err = ErrServerDeviceFailure
				break
			}

			// assemble a response PDU: byte count (1 byte per 8 coils,
			// rounded up) followed by the packed coil values
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
				payload:      []byte{uint8((len(coils) + 7) / 8)},
			}
			res.payload = append(res.payload, encodeBools(coils)...)

		case fcWriteSingleCoil:
			if len(req.payload) != 4 {
				err = ErrProtocolError
				break
			}
			// decode the address field
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])

			// validate the value field (should be either 0xff00 or 0x0000)
			if (req.payload[2] != 0xff && req.payload[2] != 0x00) ||
				req.payload[3] != 0x00 {
				err = ErrProtocolError
				break
			}

			// invoke the coil handler
			_, err = ms.handler.HandleCoils(&CoilsRequest{
				ClientAddr: clientAddr,
				ClientRole: clientRole,
				UnitId:     req.unitId,
				Addr:       addr,
				Quantity:   1,    // request for a single coil
				IsWrite:    true, // this is a write request
				Args:       []bool{(req.payload[2] == 0xff)},
			})
			if err != nil {
				break
			}

			// echo the address and value in the response
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
			}
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, addr)...)
			res.payload = append(res.payload,
				req.payload[2], req.payload[3])

		case fcWriteMultipleCoils:
			var expectedLen int

			if len(req.payload) < 6 {
				err = ErrProtocolError
				break
			}
			// decode address and quantity fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// ensure the reply never exceeds the maximum PDU length and we
			// never read past 0xffff
			if quantity > 0x7b0 || quantity == 0 {
				err = ErrProtocolError
				break
			}
			if uint32(addr)+uint32(quantity)-1 > 0xffff {
				err = ErrIllegalDataAddress
				break
			}

			// validate the byte count field (1 byte for 8 coils)
			expectedLen = (int(quantity) + 7) / 8
			if req.payload[4] != uint8(expectedLen) {
				err = ErrProtocolError
				break
			}
			// make sure we have exactly that many bytes
			if len(req.payload)-5 != expectedLen {
				err = ErrProtocolError
				break
			}

			// invoke the coil handler
			_, err = ms.handler.HandleCoils(&CoilsRequest{
				ClientAddr: clientAddr,
				ClientRole: clientRole,
				UnitId:     req.unitId,
				Addr:       addr,
				Quantity:   quantity,
				IsWrite:    true, // this is a write request
				Args:       decodeBools(quantity, req.payload[5:]),
			})
			if err != nil {
				break
			}

			// echo the address and quantity in the response
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
			}
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, addr)...)
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, quantity)...)

		case fcReadHoldingRegisters, fcReadInputRegisters:
			var regs []uint16

			if len(req.payload) != 4 {
				err = ErrProtocolError
				break
			}
			// decode address and quantity fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// ensure the reply never exceeds the maximum PDU length and we
			// never read past 0xffff
			if quantity > 0x007d || quantity == 0 {
				err = ErrProtocolError
				break
			}
			if uint32(addr)+uint32(quantity)-1 > 0xffff {
				err = ErrIllegalDataAddress
				break
			}

			// invoke the appropriate handler
			if req.functionCode == fcReadHoldingRegisters {
				regs, err = ms.handler.HandleHoldingRegisters(
					&HoldingRegistersRequest{
						ClientAddr: clientAddr,
						ClientRole: clientRole,
						UnitId:     req.unitId,
						Addr:       addr,
						Quantity:   quantity,
						IsWrite:    false,
						Args:       nil,
					})
			} else {
				regs, err = ms.handler.HandleInputRegisters(
					&InputRegistersRequest{
						ClientAddr: clientAddr,
						ClientRole: clientRole,
						UnitId:     req.unitId,
						Addr:       addr,
						Quantity:   quantity,
					})
			}
			if err != nil {
				break
			}
			// make sure the handler returned the expected number of items
			if len(regs) != int(quantity) {
				ms.logger.Errorf("handler returned %v 16-bit values, "+
					"expected %v", len(regs), quantity)
				err = ErrServerDeviceFailure
				break
			}

			// assemble a response PDU: byte count (2 bytes per register)
			// followed by the register values
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
				payload:      []byte{uint8(len(regs) * 2)},
			}
			res.payload = append(res.payload,
				uint16sToBytes(BIG_ENDIAN, regs)...)

		case fcWriteSingleRegister:
			var value uint16

			if len(req.payload) != 4 {
				err = ErrProtocolError
				break
			}
			// decode address and value fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			value = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// invoke the handler
			_, err = ms.handler.HandleHoldingRegisters(
				&HoldingRegistersRequest{
					ClientAddr: clientAddr,
					ClientRole: clientRole,
					UnitId:     req.unitId,
					Addr:       addr,
					Quantity:   1,    // request for a single register
					IsWrite:    true, // request is a write
					Args:       []uint16{value},
				})
			if err != nil {
				break
			}

			// echo the address and value in the response
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
			}
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, addr)...)
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, value)...)

		case fcWriteMultipleRegisters:
			var expectedLen int

			if len(req.payload) < 6 {
				err = ErrProtocolError
				break
			}
			// decode address and quantity fields
			addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
			quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])

			// ensure the reply never exceeds the maximum PDU length and we
			// never read past 0xffff
			if quantity > 0x007b || quantity == 0 {
				err = ErrProtocolError
				break
			}
			if uint32(addr)+uint32(quantity)-1 > 0xffff {
				err = ErrIllegalDataAddress
				break
			}

			// validate the byte count field (2 bytes per register)
			expectedLen = int(quantity) * 2
			if req.payload[4] != uint8(expectedLen) {
				err = ErrProtocolError
				break
			}
			// make sure we have exactly that many bytes
			if len(req.payload)-5 != expectedLen {
				err = ErrProtocolError
				break
			}

			// invoke the holding register handler
			_, err = ms.handler.HandleHoldingRegisters(
				&HoldingRegistersRequest{
					ClientAddr: clientAddr,
					ClientRole: clientRole,
					UnitId:     req.unitId,
					Addr:       addr,
					Quantity:   quantity,
					IsWrite:    true, // this is a write request
					Args:       bytesToUint16s(BIG_ENDIAN, req.payload[5:]),
				})
			if err != nil {
				break
			}

			// echo the address and quantity in the response
			res = &pdu{
				unitId:       req.unitId,
				functionCode: req.functionCode,
			}
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, addr)...)
			res.payload = append(res.payload,
				uint16ToBytes(BIG_ENDIAN, quantity)...)

		default:
			// unknown function code: answer with an illegal function
			// exception, addressed to the request's target unit ID
			res = &pdu{
				unitId:       req.unitId,
				functionCode: (0x80 | req.functionCode),
				payload:      []byte{exIllegalFunction},
			}
		}

		// if there was no error processing the request but the response is
		// nil (which should never happen), emit a server failure exception
		// code and log an error
		if err == nil && res == nil {
			err = ErrServerDeviceFailure
			ms.logger.Errorf("internal server error (req: %v, res: %v, err: %v)",
				req, res, err)
		}

		// map go errors to modbus exceptions, unless the error is a protocol
		// error, in which case close the transport and return
		if err != nil {
			if errors.Is(err, ErrProtocolError) {
				ms.logger.Warningf(
					"protocol error, closing link (client address: '%s')",
					clientAddr)
				t.Close()
				return
			}
			res = &pdu{
				unitId:       req.unitId,
				functionCode: (0x80 | req.functionCode),
				payload:      []byte{mapErrorToExceptionCode(err)},
			}
		}

		// write the response to the transport; a failed (possibly partial)
		// write leaves the client's view of the stream in an unknown state,
		// so drop the link rather than keep serving it
		if err = t.WriteResponse(res); err != nil {
			ms.logger.Warningf("failed to write response: %v", err)
			t.Close()
			return
		}
	}
}
// startTLS performs a full TLS handshake (with client authentication) on
// tcpSock. It returns a 'wrapped' clear-text socket suitable for use by the
// TCP transport, along with the role found in the client's leaf certificate
// ("" when absent or invalid, see extractRole).
//
// The handshake must complete within 30 seconds. That deadline is cleared
// once the handshake succeeds, so it cannot cut the Modbus session short.
// Session idle timeouts are left to the transport, which is presumably
// configured with ms.conf.Timeout by the caller.
// tlsSock must not be used when err is non-nil.
func (ms *ModbusServer) startTLS(tcpSock net.Conn) (
	tlsSock *tls.Conn, clientRole string, err error) {
	// bound the handshake so a silent peer cannot hold on to a client slot
	// indefinitely
	err = tcpSock.SetDeadline(time.Now().Add(30 * time.Second))
	if err != nil {
		return
	}

	// start TLS negotiation over the raw TCP connection
	tlsSock = tls.Server(tcpSock, &tls.Config{
		Certificates: []tls.Certificate{
			*ms.conf.TLSServerCert,
		},
		ClientCAs: ms.conf.TLSClientCAs,
		// require a valid (verified) certificate from the client
		// (see R-06, R-08 and R-10 of the MBAPS spec)
		ClientAuth: tls.RequireAndVerifyClientCert,
		// mandate TLSv1.2 or higher (see R-01 of the MBAPS spec)
		MinVersion: tls.VersionTLS12,
	})

	// complete the full TLS handshake (with client cert validation)
	if err = tlsSock.Handshake(); err != nil {
		return
	}

	// the handshake deadline is still armed on the socket: clear it so the
	// session is not torn down 30s after connecting
	if err = tcpSock.SetDeadline(time.Time{}); err != nil {
		return
	}

	// look for and extract the client's role, if any
	connState := tlsSock.ConnectionState()
	if len(connState.PeerCertificates) == 0 {
		err = errors.New("no client certificate received")
		return
	}

	// From the tls.ConnectionState doc:
	// "The first element is the leaf certificate that the connection is
	// verified against."
	clientRole = ms.extractRole(connState.PeerCertificates[0])
	return
}
// extractRole looks for the Modbus Role extension (modbusRoleOID) in cert
// and returns the role it carries as a string.
// An empty string is returned when:
//   - no role extension is present (R-23),
//   - more than one role extension is present (R-65),
//   - the extension value is not exactly one well-formed ASN.1 UTF8String
//     (R-22).
func (ms *ModbusServer) extractRole(cert *x509.Certificate) (role string) {
	var found bool

	// walk through all extensions looking for Modbus Role OIDs
	for _, ext := range cert.Extensions {
		if !ext.Id.Equal(modbusRoleOID) {
			continue
		}

		// there must be only one role extension per cert (R-65)
		if found {
			ms.logger.Warning("client certificate contains more than one role OIDs")
			return ""
		}
		found = true

		// the role extension must use UTF8String encoding (R-22)
		// (the ASN1 tag for UTF8String is 0x0c)
		if len(ext.Value) < 2 || ext.Value[0] != 0x0c {
			ms.logger.Warning("Modbus Role extension is not a UTF8String")
			return ""
		}

		// extract the ASN1 string
		rest, err := asn1.Unmarshal(ext.Value, &role)
		if err != nil {
			ms.logger.Warningf("failed to decode Modbus Role extension: %v", err)
			return ""
		}
		// the extension value must hold a single ASN1 value, nothing more
		if len(rest) != 0 {
			ms.logger.Warning("trailing data after Modbus Role extension value")
			return ""
		}
	}

	return role
}