first init
This commit is contained in:
901
server.go
Normal file
901
server.go
Normal file
@@ -0,0 +1,901 @@
|
||||
package modbus
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/asn1"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Modbus Role PEM OID (see R-21 of the MBAPS spec)
|
||||
var modbusRoleOID asn1.ObjectIdentifier = asn1.ObjectIdentifier{
|
||||
1, 3, 6, 1, 4, 1, 50316, 802, 1,
|
||||
}
|
||||
|
||||
// Server configuration object.
|
||||
type ServerConfiguration struct {
|
||||
// URL defines where to listen at e.g. tcp://[::]:502
|
||||
URL string
|
||||
// Timeout sets the idle session timeout (client connections will
|
||||
// be closed if idle for this long)
|
||||
Timeout time.Duration
|
||||
// MaxClients sets the maximum number of concurrent client connections
|
||||
MaxClients uint
|
||||
// TLSServerCert sets the server-side TLS key pair (tcp+tls only)
|
||||
TLSServerCert *tls.Certificate
|
||||
// TLSClientCAs sets the list of CA certificates used to authenticate
|
||||
// client connections (tcp+tls only). Leaf (i.e. client) certificates can
|
||||
// also be used in case of self-signed certs, or if cert pinning is required.
|
||||
TLSClientCAs *x509.CertPool
|
||||
// Logger provides a custom sink for log messages.
|
||||
// If nil, messages will be written to stdout.
|
||||
Logger *log.Logger
|
||||
}
|
||||
|
||||
// Request object passed to the coil handler.
|
||||
type CoilsRequest struct {
|
||||
ClientAddr string // the source (client) IP address
|
||||
ClientRole string // the client role as encoded in the client certificate (tcp+tls only)
|
||||
UnitId uint8 // the requested unit id (slave id)
|
||||
Addr uint16 // the base coil address requested
|
||||
Quantity uint16 // the number of consecutive coils covered by this request
|
||||
// (first address: Addr, last address: Addr + Quantity - 1)
|
||||
IsWrite bool // true if the request is a write, false if a read
|
||||
Args []bool // a slice of bool values of the coils to be set, ordered
|
||||
// from Addr to Addr + Quantity - 1 (for writes only)
|
||||
}
|
||||
|
||||
// Request object passed to the discrete input handler.
|
||||
type DiscreteInputsRequest struct {
|
||||
ClientAddr string // the source (client) IP address
|
||||
ClientRole string // the client role as encoded in the client certificate (tcp+tls only)
|
||||
UnitId uint8 // the requested unit id (slave id)
|
||||
Addr uint16 // the base discrete input address requested
|
||||
Quantity uint16 // the number of consecutive discrete inputs covered by this request
|
||||
}
|
||||
|
||||
// Request object passed to the holding register handler.
|
||||
type HoldingRegistersRequest struct {
|
||||
ClientAddr string // the source (client) IP address
|
||||
ClientRole string // the client role as encoded in the client certificate (tcp+tls only)
|
||||
UnitId uint8 // the requested unit id (slave id)
|
||||
Addr uint16 // the base register address requested
|
||||
Quantity uint16 // the number of consecutive registers covered by this request
|
||||
IsWrite bool // true if the request is a write, false if a read
|
||||
Args []uint16 // a slice of register values to be set, ordered from
|
||||
// Addr to Addr + Quantity - 1 (for writes only)
|
||||
}
|
||||
|
||||
// Request object passed to the input register handler.
|
||||
type InputRegistersRequest struct {
|
||||
ClientAddr string // the source (client) IP address
|
||||
ClientRole string // the client role as encoded in the client certificate (tcp+tls only)
|
||||
UnitId uint8 // the requested unit id (slave id)
|
||||
Addr uint16 // the base register address requested
|
||||
Quantity uint16 // the number of consecutive registers covered by this request
|
||||
}
|
||||
|
||||
// The RequestHandler interface should be implemented by the handler
|
||||
// object passed to NewServer (see reqHandler in NewServer()).
|
||||
// After decoding and validating an incoming request, the server will
|
||||
// invoke the appropriate handler function, depending on the function code
|
||||
// of the request.
|
||||
type RequestHandler interface {
|
||||
// HandleCoils handles the read coils (0x01), write single coil (0x05)
|
||||
// and write multiple coils (0x0f) function codes.
|
||||
// A CoilsRequest object is passed to the handler (see above).
|
||||
//
|
||||
// Expected return values:
|
||||
// - res: a slice of bools containing the coil values to be sent to back
|
||||
// to the client (only sent for reads),
|
||||
// - err: either nil if no error occurred, a modbus error (see
|
||||
// mapErrorToExceptionCode() in modbus.go for a complete list),
|
||||
// or any other error.
|
||||
// If nil, a positive modbus response is sent back to the client
|
||||
// along with the returned data.
|
||||
// If non-nil, a negative modbus response is sent back, with the
|
||||
// exception code set depending on the error
|
||||
// (again, see mapErrorToExceptionCode()).
|
||||
HandleCoils (req *CoilsRequest) (res []bool, err error)
|
||||
|
||||
// HandleDiscreteInputs handles the read discrete inputs (0x02) function code.
|
||||
// A DiscreteInputsRequest oibject is passed to the handler (see above).
|
||||
//
|
||||
// Expected return values:
|
||||
// - res: a slice of bools containing the discrete input values to be
|
||||
// sent back to the client,
|
||||
// - err: either nil if no error occurred, a modbus error (see
|
||||
// mapErrorToExceptionCode() in modbus.go for a complete list),
|
||||
// or any other error.
|
||||
HandleDiscreteInputs (req *DiscreteInputsRequest) (res []bool, err error)
|
||||
|
||||
// HandleHoldingRegisters handles the read holding registers (0x03),
|
||||
// write single register (0x06) and write multiple registers (0x10).
|
||||
// A HoldingRegistersRequest object is passed to the handler (see above).
|
||||
//
|
||||
// Expected return values:
|
||||
// - res: a slice of uint16 containing the register values to be sent
|
||||
// to back to the client (only sent for reads),
|
||||
// - err: either nil if no error occurred, a modbus error (see
|
||||
// mapErrorToExceptionCode() in modbus.go for a complete list),
|
||||
// or any other error.
|
||||
HandleHoldingRegisters (req *HoldingRegistersRequest) (res []uint16, err error)
|
||||
|
||||
// HandleInputRegisters handles the read input registers (0x04) function code.
|
||||
// An InputRegistersRequest object is passed to the handler (see above).
|
||||
//
|
||||
// Expected return values:
|
||||
// - res: a slice of uint16 containing the register values to be sent
|
||||
// back to the client,
|
||||
// - err: either nil if no error occurred, a modbus error (see
|
||||
// mapErrorToExceptionCode() in modbus.go for a complete list),
|
||||
// or any other error.
|
||||
HandleInputRegisters (req *InputRegistersRequest) (res []uint16, err error)
|
||||
}
|
||||
|
||||
// Modbus server object.
|
||||
type ModbusServer struct {
|
||||
conf ServerConfiguration
|
||||
logger *logger
|
||||
lock sync.Mutex
|
||||
started bool
|
||||
handler RequestHandler
|
||||
tcpListener net.Listener
|
||||
tcpClients []net.Conn
|
||||
transportType transportType
|
||||
}
|
||||
|
||||
// Returns a new modbus server.
|
||||
// reqHandler should be a user-provided handler object satisfying the RequestHandler
|
||||
// interface.
|
||||
func NewServer(conf *ServerConfiguration, reqHandler RequestHandler) (
|
||||
ms *ModbusServer, err error) {
|
||||
var serverType string
|
||||
var splitURL []string
|
||||
|
||||
ms = &ModbusServer{
|
||||
conf: *conf,
|
||||
handler: reqHandler,
|
||||
}
|
||||
|
||||
splitURL = strings.SplitN(ms.conf.URL, "://", 2)
|
||||
if len(splitURL) == 2 {
|
||||
serverType = splitURL[0]
|
||||
ms.conf.URL = splitURL[1]
|
||||
}
|
||||
|
||||
ms.logger = newLogger(
|
||||
fmt.Sprintf("modbus-server(%s)", ms.conf.URL), ms.conf.Logger)
|
||||
|
||||
if ms.conf.URL == "" {
|
||||
ms.logger.Errorf("missing host part in URL '%s'", conf.URL)
|
||||
err = ErrConfigurationError
|
||||
return
|
||||
}
|
||||
|
||||
switch serverType {
|
||||
case "tcp":
|
||||
if ms.conf.Timeout == 0 {
|
||||
ms.conf.Timeout = 120 * time.Second
|
||||
}
|
||||
|
||||
if ms.conf.MaxClients == 0 {
|
||||
ms.conf.MaxClients = 10
|
||||
}
|
||||
|
||||
ms.transportType = modbusTCP
|
||||
|
||||
case "tcp+tls":
|
||||
if ms.conf.Timeout == 0 {
|
||||
ms.conf.Timeout = 120 * time.Second
|
||||
}
|
||||
|
||||
if ms.conf.MaxClients == 0 {
|
||||
ms.conf.MaxClients = 10
|
||||
}
|
||||
|
||||
// expect a server-side certificate
|
||||
if ms.conf.TLSServerCert == nil {
|
||||
ms.logger.Errorf("missing server certificate")
|
||||
err = ErrConfigurationError
|
||||
return
|
||||
}
|
||||
|
||||
// expect a CertPool object containing at least 1 CA or
|
||||
// leaf certificate to validate client-side certificates
|
||||
if ms.conf.TLSClientCAs == nil {
|
||||
ms.logger.Errorf("missing CA/client certificates")
|
||||
err = ErrConfigurationError
|
||||
return
|
||||
}
|
||||
|
||||
ms.transportType = modbusTCPOverTLS
|
||||
|
||||
default:
|
||||
err = ErrConfigurationError
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Starts accepting client connections.
|
||||
func (ms *ModbusServer) Start() (err error) {
|
||||
ms.lock.Lock()
|
||||
defer ms.lock.Unlock()
|
||||
|
||||
if ms.started {
|
||||
return
|
||||
}
|
||||
|
||||
switch ms.transportType {
|
||||
case modbusTCP, modbusTCPOverTLS:
|
||||
// bind to a TCP socket
|
||||
ms.tcpListener, err = net.Listen("tcp", ms.conf.URL)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// accept client connections in a goroutine
|
||||
go ms.acceptTCPClients()
|
||||
|
||||
default:
|
||||
err = ErrConfigurationError
|
||||
return
|
||||
}
|
||||
|
||||
ms.started = true
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Stops accepting new client connections and closes any active session.
|
||||
func (ms *ModbusServer) Stop() (err error) {
|
||||
ms.lock.Lock()
|
||||
defer ms.lock.Unlock()
|
||||
|
||||
if !ms.started {
|
||||
return
|
||||
}
|
||||
|
||||
ms.started = false
|
||||
|
||||
if ms.transportType == modbusTCP || ms.transportType == modbusTCPOverTLS {
|
||||
// close the server socket if we're listening over TCP
|
||||
err = ms.tcpListener.Close()
|
||||
|
||||
// close all active TCP clients
|
||||
for _, sock := range ms.tcpClients{
|
||||
sock.Close()
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Accepts new client connections if the configured connection limit allows it.
|
||||
// Each connection is served from a dedicated goroutine to allow for concurrent
|
||||
// connections.
|
||||
func (ms *ModbusServer) acceptTCPClients() {
|
||||
var sock net.Conn
|
||||
var err error
|
||||
var accepted bool
|
||||
|
||||
for {
|
||||
sock, err = ms.tcpListener.Accept()
|
||||
if err != nil {
|
||||
// if the server socket has just been closed, return here as
|
||||
// this goroutine isn't going to see any new client connection
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
ms.logger.Warningf("failed to accept client connection: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
ms.lock.Lock()
|
||||
// apply a connection limit
|
||||
if ms.started && uint(len(ms.tcpClients)) < ms.conf.MaxClients {
|
||||
accepted = true
|
||||
// add the new client connection to the pool
|
||||
ms.tcpClients = append(ms.tcpClients, sock)
|
||||
} else {
|
||||
accepted = false
|
||||
}
|
||||
ms.lock.Unlock()
|
||||
|
||||
if accepted {
|
||||
// spin a client handler goroutine to serve the new client
|
||||
go ms.handleTCPClient(sock)
|
||||
} else {
|
||||
ms.logger.Warningf("max. number of concurrent connections " +
|
||||
"reached, rejecting %v", sock.RemoteAddr())
|
||||
// discard the connection
|
||||
sock.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// never reached
|
||||
return
|
||||
}
|
||||
|
||||
// Handles a TCP client connection.
|
||||
// Once handleTransport() returns (i.e. the connection has either closed, timed
|
||||
// out, or an unrecoverable error happened), the TCP socket is closed and removed
|
||||
// from the list of active client connections.
|
||||
func (ms *ModbusServer) handleTCPClient(sock net.Conn) {
|
||||
var err error
|
||||
var clientRole string
|
||||
var tlsSock net.Conn
|
||||
|
||||
switch ms.transportType {
|
||||
case modbusTCP:
|
||||
// serve modbus requests over the raw TCP connection
|
||||
ms.handleTransport(
|
||||
newTCPTransport(sock, ms.conf.Timeout, ms.conf.Logger),
|
||||
sock.RemoteAddr().String(), "")
|
||||
|
||||
case modbusTCPOverTLS:
|
||||
// start TLS negotiation over the raw TCP connection
|
||||
tlsSock, clientRole, err = ms.startTLS(sock)
|
||||
if err != nil {
|
||||
ms.logger.Warningf("TLS handshake with %s failed: %v",
|
||||
sock.RemoteAddr().String(), err)
|
||||
} else {
|
||||
// serve modbus requests over the TLS tunnel
|
||||
ms.handleTransport(
|
||||
newTCPTransport(tlsSock, ms.conf.Timeout, ms.conf.Logger),
|
||||
sock.RemoteAddr().String(), clientRole)
|
||||
}
|
||||
|
||||
default:
|
||||
ms.logger.Errorf("unimplemented transport type %v", ms.transportType)
|
||||
}
|
||||
|
||||
// once done, remove our connection from the list of active client conns
|
||||
ms.lock.Lock()
|
||||
for i := range ms.tcpClients {
|
||||
if ms.tcpClients[i] == sock {
|
||||
ms.tcpClients[i] = ms.tcpClients[len(ms.tcpClients)-1]
|
||||
ms.tcpClients = ms.tcpClients[:len(ms.tcpClients)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
ms.lock.Unlock()
|
||||
|
||||
// close the connection
|
||||
sock.Close()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// For each request read from the transport, performs decoding and validation,
|
||||
// calls the user-provided handler, then encodes and writes the response
|
||||
// to the transport.
|
||||
func (ms *ModbusServer) handleTransport(t transport, clientAddr string, clientRole string) {
|
||||
var req *pdu
|
||||
var res *pdu
|
||||
var err error
|
||||
var addr uint16
|
||||
var quantity uint16
|
||||
|
||||
for {
|
||||
req, err = t.ReadRequest()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch req.functionCode {
|
||||
case fcReadCoils, fcReadDiscreteInputs:
|
||||
var coils []bool
|
||||
var resCount int
|
||||
|
||||
if len(req.payload) != 4 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// decode address and quantity fields
|
||||
addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
|
||||
quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])
|
||||
|
||||
// ensure the reply never exceeds the maximum PDU length and we
|
||||
// never read past 0xffff
|
||||
if quantity > 2000 || quantity == 0 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
if uint32(addr) + uint32(quantity) - 1 > 0xffff {
|
||||
err = ErrIllegalDataAddress
|
||||
break
|
||||
}
|
||||
|
||||
// invoke the appropriate handler
|
||||
if req.functionCode == fcReadCoils {
|
||||
coils, err = ms.handler.HandleCoils(&CoilsRequest{
|
||||
ClientAddr: clientAddr,
|
||||
ClientRole: clientRole,
|
||||
UnitId: req.unitId,
|
||||
Addr: addr,
|
||||
Quantity: quantity,
|
||||
IsWrite: false,
|
||||
Args: nil,
|
||||
})
|
||||
} else {
|
||||
coils, err = ms.handler.HandleDiscreteInputs(
|
||||
&DiscreteInputsRequest{
|
||||
ClientAddr: clientAddr,
|
||||
ClientRole: clientRole,
|
||||
UnitId: req.unitId,
|
||||
Addr: addr,
|
||||
Quantity: quantity,
|
||||
})
|
||||
}
|
||||
resCount = len(coils)
|
||||
|
||||
// make sure the handler returned the expected number of items
|
||||
if err == nil && resCount != int(quantity) {
|
||||
ms.logger.Errorf("handler returned %v bools, " +
|
||||
"expected %v", resCount, quantity)
|
||||
err = ErrServerDeviceFailure
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// assemble a response PDU
|
||||
res = &pdu{
|
||||
unitId: req.unitId,
|
||||
functionCode: req.functionCode,
|
||||
payload: []byte{0},
|
||||
}
|
||||
|
||||
// byte count (1 byte for 8 coils)
|
||||
res.payload[0] = uint8(resCount / 8)
|
||||
if resCount % 8 != 0 {
|
||||
res.payload[0]++
|
||||
}
|
||||
|
||||
// coil values
|
||||
res.payload = append(res.payload, encodeBools(coils)...)
|
||||
|
||||
case fcWriteSingleCoil:
|
||||
if len(req.payload) != 4 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// decode the address field
|
||||
addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
|
||||
|
||||
// validate the value field (should be either 0xff00 or 0x0000)
|
||||
if ((req.payload[2] != 0xff && req.payload[2] != 0x00) ||
|
||||
req.payload[3] != 0x00) {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// invoke the coil handler
|
||||
_, err = ms.handler.HandleCoils(&CoilsRequest{
|
||||
ClientAddr: clientAddr,
|
||||
ClientRole: clientRole,
|
||||
UnitId: req.unitId,
|
||||
Addr: addr,
|
||||
Quantity: 1, // request for a single coil
|
||||
IsWrite: true, // this is a write request
|
||||
Args: []bool{(req.payload[2] == 0xff)},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// assemble a response PDU
|
||||
res = &pdu{
|
||||
unitId: req.unitId,
|
||||
functionCode: req.functionCode,
|
||||
}
|
||||
|
||||
// echo the address and value in the response
|
||||
res.payload = append(res.payload,
|
||||
uint16ToBytes(BIG_ENDIAN, addr)...)
|
||||
res.payload = append(res.payload,
|
||||
req.payload[2], req.payload[3])
|
||||
|
||||
case fcWriteMultipleCoils:
|
||||
var expectedLen int
|
||||
|
||||
if len(req.payload) < 6 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// decode address and quantity fields
|
||||
addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
|
||||
quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])
|
||||
|
||||
// ensure the reply never exceeds the maximum PDU length and we
|
||||
// never read past 0xffff
|
||||
if quantity > 0x7b0 || quantity == 0 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
if uint32(addr) + uint32(quantity) - 1 > 0xffff {
|
||||
err = ErrIllegalDataAddress
|
||||
break
|
||||
}
|
||||
|
||||
// validate the byte count field (1 byte for 8 coils)
|
||||
expectedLen = int(quantity) / 8
|
||||
if quantity % 8 != 0 {
|
||||
expectedLen++
|
||||
}
|
||||
|
||||
if req.payload[4] != uint8(expectedLen) {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// make sure we have enough bytes
|
||||
if len(req.payload) - 5 != expectedLen {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// invoke the coil handler
|
||||
_, err = ms.handler.HandleCoils(&CoilsRequest{
|
||||
ClientAddr: clientAddr,
|
||||
ClientRole: clientRole,
|
||||
UnitId: req.unitId,
|
||||
Addr: addr,
|
||||
Quantity: quantity,
|
||||
IsWrite: true, // this is a write request
|
||||
Args: decodeBools(quantity, req.payload[5:]),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// assemble a response PDU
|
||||
res = &pdu{
|
||||
unitId: req.unitId,
|
||||
functionCode: req.functionCode,
|
||||
}
|
||||
|
||||
// echo the address and quantity in the response
|
||||
res.payload = append(res.payload,
|
||||
uint16ToBytes(BIG_ENDIAN, addr)...)
|
||||
res.payload = append(res.payload,
|
||||
uint16ToBytes(BIG_ENDIAN, quantity)...)
|
||||
|
||||
case fcReadHoldingRegisters, fcReadInputRegisters:
|
||||
var regs []uint16
|
||||
var resCount int
|
||||
|
||||
if len(req.payload) != 4 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// decode address and quantity fields
|
||||
addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
|
||||
quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])
|
||||
|
||||
// ensure the reply never exceeds the maximum PDU length and we
|
||||
// never read past 0xffff
|
||||
if quantity > 0x007d || quantity == 0 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
if uint32(addr) + uint32(quantity) - 1 > 0xffff {
|
||||
err = ErrIllegalDataAddress
|
||||
break
|
||||
}
|
||||
|
||||
// invoke the appropriate handler
|
||||
if req.functionCode == fcReadHoldingRegisters {
|
||||
regs, err = ms.handler.HandleHoldingRegisters(
|
||||
&HoldingRegistersRequest{
|
||||
ClientAddr: clientAddr,
|
||||
ClientRole: clientRole,
|
||||
UnitId: req.unitId,
|
||||
Addr: addr,
|
||||
Quantity: quantity,
|
||||
IsWrite: false,
|
||||
Args: nil,
|
||||
})
|
||||
} else {
|
||||
regs, err = ms.handler.HandleInputRegisters(
|
||||
&InputRegistersRequest{
|
||||
ClientAddr: clientAddr,
|
||||
ClientRole: clientRole,
|
||||
UnitId: req.unitId,
|
||||
Addr: addr,
|
||||
Quantity: quantity,
|
||||
})
|
||||
}
|
||||
resCount = len(regs)
|
||||
|
||||
// make sure the handler returned the expected number of items
|
||||
if err == nil && resCount != int(quantity) {
|
||||
ms.logger.Errorf("handler returned %v 16-bit values, " +
|
||||
"expected %v", resCount, quantity)
|
||||
err = ErrServerDeviceFailure
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// assemble a response PDU
|
||||
res = &pdu{
|
||||
unitId: req.unitId,
|
||||
functionCode: req.functionCode,
|
||||
payload: []byte{0},
|
||||
}
|
||||
|
||||
// byte count (2 bytes per register)
|
||||
res.payload[0] = uint8(resCount * 2)
|
||||
|
||||
// register values
|
||||
res.payload = append(res.payload,
|
||||
uint16sToBytes(BIG_ENDIAN, regs)...)
|
||||
|
||||
case fcWriteSingleRegister:
|
||||
var value uint16
|
||||
|
||||
if len(req.payload) != 4 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// decode address and value fields
|
||||
addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
|
||||
value = bytesToUint16(BIG_ENDIAN, req.payload[2:4])
|
||||
|
||||
// invoke the handler
|
||||
_, err = ms.handler.HandleHoldingRegisters(
|
||||
&HoldingRegistersRequest{
|
||||
ClientAddr: clientAddr,
|
||||
ClientRole: clientRole,
|
||||
UnitId: req.unitId,
|
||||
Addr: addr,
|
||||
Quantity: 1, // request for a single register
|
||||
IsWrite: true, // request is a write
|
||||
Args: []uint16{value},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// assemble a response PDU
|
||||
res = &pdu{
|
||||
unitId: req.unitId,
|
||||
functionCode: req.functionCode,
|
||||
}
|
||||
|
||||
// echo the address and value in the response
|
||||
res.payload = append(res.payload,
|
||||
uint16ToBytes(BIG_ENDIAN, addr)...)
|
||||
res.payload = append(res.payload,
|
||||
uint16ToBytes(BIG_ENDIAN, value)...)
|
||||
|
||||
case fcWriteMultipleRegisters:
|
||||
var expectedLen int
|
||||
|
||||
if len(req.payload) < 6 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// decode address and quantity fields
|
||||
addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2])
|
||||
quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4])
|
||||
|
||||
// ensure the reply never exceeds the maximum PDU length and we
|
||||
// never read past 0xffff
|
||||
if quantity > 0x007b || quantity == 0 {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
if uint32(addr) + uint32(quantity) - 1 > 0xffff {
|
||||
err = ErrIllegalDataAddress
|
||||
break
|
||||
}
|
||||
|
||||
// validate the byte count field (2 bytes per register)
|
||||
expectedLen = int(quantity) * 2
|
||||
|
||||
if req.payload[4] != uint8(expectedLen) {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// make sure we have enough bytes
|
||||
if len(req.payload) - 5 != expectedLen {
|
||||
err = ErrProtocolError
|
||||
break
|
||||
}
|
||||
|
||||
// invoke the holding register handler
|
||||
_, err = ms.handler.HandleHoldingRegisters(
|
||||
&HoldingRegistersRequest{
|
||||
ClientAddr: clientAddr,
|
||||
ClientRole: clientRole,
|
||||
UnitId: req.unitId,
|
||||
Addr: addr,
|
||||
Quantity: quantity,
|
||||
IsWrite: true, // this is a write request
|
||||
Args: bytesToUint16s(BIG_ENDIAN, req.payload[5:]),
|
||||
})
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
// assemble a response PDU
|
||||
res = &pdu{
|
||||
unitId: req.unitId,
|
||||
functionCode: req.functionCode,
|
||||
}
|
||||
|
||||
// echo the address and quantity in the response
|
||||
res.payload = append(res.payload,
|
||||
uint16ToBytes(BIG_ENDIAN, addr)...)
|
||||
res.payload = append(res.payload,
|
||||
uint16ToBytes(BIG_ENDIAN, quantity)...)
|
||||
|
||||
default:
|
||||
res = &pdu{
|
||||
// reply with the request target unit ID
|
||||
unitId: req.unitId,
|
||||
// set the error bit
|
||||
functionCode: (0x80 | req.functionCode),
|
||||
// set the exception code to illegal function to indicate that
|
||||
// the server does not know how to handle this function code.
|
||||
payload: []byte{exIllegalFunction},
|
||||
}
|
||||
}
|
||||
|
||||
// if there was no error processing the request but the response is nil
|
||||
// (which should never happen), emit a server failure exception code
|
||||
// and log an error
|
||||
if err == nil && res == nil {
|
||||
err = ErrServerDeviceFailure
|
||||
ms.logger.Errorf("internal server error (req: %v, res: %v, err: %v)",
|
||||
req, res, err)
|
||||
}
|
||||
|
||||
// map go errors to modbus errors, unless the error is a protocol error,
|
||||
// in which case close the transport and return.
|
||||
if err != nil {
|
||||
if err == ErrProtocolError {
|
||||
ms.logger.Warningf(
|
||||
"protocol error, closing link (client address: '%s')",
|
||||
clientAddr)
|
||||
t.Close()
|
||||
return
|
||||
} else {
|
||||
res = &pdu{
|
||||
unitId: req.unitId,
|
||||
functionCode: (0x80 | req.functionCode),
|
||||
payload: []byte{mapErrorToExceptionCode(err)},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write the response to the transport
|
||||
err = t.WriteResponse(res)
|
||||
if err != nil {
|
||||
ms.logger.Warningf("failed to write response: %v", err)
|
||||
}
|
||||
|
||||
// avoid holding on to stale data
|
||||
req = nil
|
||||
res = nil
|
||||
}
|
||||
|
||||
// never reached
|
||||
return
|
||||
}
|
||||
|
||||
// startTLS performs a full TLS handshake (with client authentication) on tcpSock
|
||||
// and returns a 'wrapped' clear-text socket suitable for use by the TCP transport.
|
||||
func (ms *ModbusServer) startTLS(tcpSock net.Conn) (
|
||||
tlsSock *tls.Conn, clientRole string, err error) {
|
||||
var connState tls.ConnectionState
|
||||
|
||||
// set a 30s timeout for the TLS handshake to complete
|
||||
err = tcpSock.SetDeadline(time.Now().Add(30 * time.Second))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// start TLS negotiation over the raw TCP connection
|
||||
tlsSock = tls.Server(tcpSock, &tls.Config{
|
||||
Certificates: []tls.Certificate{
|
||||
*ms.conf.TLSServerCert,
|
||||
},
|
||||
ClientCAs: ms.conf.TLSClientCAs,
|
||||
// require a valid (verified) certificate from the client
|
||||
// (see R-06, R-08 and R-10 of the MBAPS spec)
|
||||
ClientAuth: tls.RequireAndVerifyClientCert,
|
||||
// mandate TLSv1.2 or higher (see R-01 of the MBAPS spec)
|
||||
MinVersion: tls.VersionTLS12,
|
||||
})
|
||||
|
||||
// complete the full TLS handshake (with client cert validation)
|
||||
err = tlsSock.Handshake()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// look for and extract the client's role, if any
|
||||
connState = tlsSock.ConnectionState()
|
||||
if len(connState.PeerCertificates) == 0 {
|
||||
err = errors.New("no client certificate received")
|
||||
return
|
||||
}
|
||||
// From the tls.ConnectionState doc:
|
||||
// "The first element is the leaf certificate that the connection is
|
||||
// verified against."
|
||||
clientRole = ms.extractRole(connState.PeerCertificates[0])
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// extractRole looks for Modbus Role extensions in a certificate and returns the
|
||||
// role as a string.
|
||||
// If no role extension is found, a nil string is returned (R-23).
|
||||
// If multiple or invalid role extensions are found, a nil string is returned (R-65, R-22).
|
||||
func (ms *ModbusServer) extractRole(cert *x509.Certificate) (role string) {
|
||||
var err error
|
||||
var found bool
|
||||
var badCert bool
|
||||
|
||||
// walk through all extensions looking for Modbus Role OIDs
|
||||
for _, ext := range cert.Extensions {
|
||||
if ext.Id.Equal(modbusRoleOID) {
|
||||
|
||||
// there must be only one role extension per cert (R-65)
|
||||
if found {
|
||||
ms.logger.Warning("client certificate contains more than one role OIDs")
|
||||
badCert = true
|
||||
break
|
||||
}
|
||||
found = true
|
||||
|
||||
// the role extension must use UTF8String encoding (R-22)
|
||||
// (the ASN1 tag for UTF8String is 0x0c)
|
||||
if len(ext.Value) < 2 || ext.Value[0] != 0x0c {
|
||||
badCert = true
|
||||
break
|
||||
}
|
||||
|
||||
// extract the ASN1 string
|
||||
_, err = asn1.Unmarshal(ext.Value, &role)
|
||||
if err != nil {
|
||||
ms.logger.Warningf("failed to decode Modbus Role extension: %v", err)
|
||||
badCert = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// blank the role if we found more than one Role extension
|
||||
if badCert {
|
||||
role = ""
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user