diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..dcaedbc --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Simon Vetter + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index e69de29..d66e92e 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,199 @@ +## Go modbus stack + +### Description +This package is a go implementation of the modbus protocol. +It aims to provide a simple-to-use, high-level API to interact with modbus +devices using native Go types. + +Both client and server components are available. + +The client supports the following modes: +- modbus RTU (serial, over both RS-232 and RS-485), +- modbus TCP (a.k.a. MBAP), +- modbus TCP over TLS (a.k.a. MBAPS or Modbus Security), +- modbus TCP over UDP (a.k.a. MBAP over UDP), +- modbus RTU over TCP (RTU tunneled in TCP for use with e.g. remote serial + ports or cheap TCP to serial bridges), +- modbus RTU over UDP (RTU tunneled in UDP). + +Please note that UDP transports are not part of the Modbus specification. +Some devices expect MBAP (modbus TCP) framing in UDP packets while others +use RTU frames instead. The client support both so if unsure, try with +both udp:// and rtuoverudp:// schemes. + +The server supports: +- modbus TCP (a.k.a. MBAP), +- modbus TCP over TLS (a.k.a. MBAPS or Modbus Security). + +A CLI client is available in cmd/modbus-cli.go and can be built with +```bash +$ go build -o modbus-cli cmd/modbus-cli.go +$ ./modbus-cli --help +``` + +### Getting started +```bash +$ go get github.com/simonvetter/modbus +``` + +### Using the client + +```golang +import ( + "github.com/simonvetter/modbus" +) + +func main() { + var client *modbus.ModbusClient + var err error + + // for a TCP endpoint + // (see examples/tls_client.go for TLS usage and options) + client, err = modbus.NewClient(&modbus.ClientConfiguration{ + URL: "tcp://hostname-or-ip-address:502", + Timeout: 1 * time.Second, + }) + // note: use udp:// for modbus TCP over UDP + + // for an RTU (serial) device/bus + client, err = modbus.NewClient(&modbus.ClientConfiguration{ + URL: "rtu:///dev/ttyUSB0", + Speed: 19200, // default + DataBits: 8, // default, optional + Parity: modbus.PARITY_NONE, // default, optional + StopBits: 2, // default if no parity, optional + Timeout: 300 * time.Millisecond, + }) + + // for an RTU over TCP device/bus (remote serial port or + // simple TCP-to-serial bridge) + client, err = modbus.NewClient(&modbus.ClientConfiguration{ + URL: "rtuovertcp://hostname-or-ip-address:502", + Speed: 19200, // serial link speed + Timeout: 1 * time.Second, + }) + // note: use rtuoverudp:// for modbus RTU over UDP + + if err != nil { + // error out if client creation failed + } + + // now that the client is created and configured, attempt to connect + err = client.Open() + if err != nil { + // error out if we failed to connect/open the device + // note: multiple Open() attempts can be made on the same client until + // the connection succeeds (i.e. err == nil), calling the constructor again + // is unnecessary. + // likewise, a client can be opened and closed as many times as needed. + } + + // read a single 16-bit holding register at address 100 + var reg16 uint16 + reg16, err = client.ReadRegister(100, modbus.HOLDING_REGISTER) + if err != nil { + // error out + } else { + // use value + fmt.Printf("value: %v", reg16) // as unsigned integer + fmt.Printf("value: %v", int16(reg16)) // as signed integer + } + + // read 4 consecutive 16-bit input registers starting at address 100 + var reg16s []uint16 + reg16s, err = client.ReadRegisters(100, 4, modbus.INPUT_REGISTER) + + // read the same 4 consecutive 16-bit input registers as 2 32-bit integers + var reg32s []uint32 + reg32s, err = client.ReadUint32s(100, 2, modbus.INPUT_REGISTER) + + // read the same 4 consecutive 16-bit registers as a single 64-bit integer + var reg64 uint64 + reg64, err = client.ReadUint64(100, modbus.INPUT_REGISTER) + + // read the same 4 consecutive 16-bit registers as a slice of bytes + var regBs []byte + regBs, err = client.ReadBytes(100, 8, modbus.INPUT_REGISTER) + + // by default, 16-bit integers are decoded as big-endian and 32/64-bit values as + // big-endian with the high word first. + // change the byte/word ordering of subsequent requests to little endian, with + // the low word first (note that the second argument only affects 32/64-bit values) + client.SetEncoding(modbus.LITTLE_ENDIAN, modbus.LOW_WORD_FIRST) + + // read the same 4 consecutive 16-bit input registers as 2 32-bit floats + var fl32s []float32 + fl32s, err = client.ReadFloat32s(100, 2, modbus.INPUT_REGISTER) + + // write -200 to 16-bit (holding) register 100, as a signed integer + var s int16 = -200 + err = client.WriteRegister(100, uint16(s)) + + // Switch to unit ID (a.k.a. slave ID) #4 + client.SetUnitId(4) + + // write 3 floats to registers 100 to 105 + err = client.WriteFloat32s(100, []float32{ + 3.14, + 1.1, + -783.22, + }) + + // write 0x0102030405060708 to 16-bit (holding) registers 10 through 13 + // (8 bytes i.e. 4 consecutive modbus registers) + err = client.WriteBytes(10, []byte{ + 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, + }) + + // close the TCP connection/serial port + client.Close() +} +``` +### Using the server component +See: +* [examples/tcp_server.go](examples/tcp_server.go) for a modbus TCP example +* [examples/tls_server.go](examples/tls_server.go) for TLS and Modbus Security features + +### Supported function codes, golang object types and endianness/word ordering +Function codes: +* Read coils (0x01) +* Read discrete inputs (0x02) +* Read holding registers (0x03) +* Read input registers (0x04) +* Write single coil (0x05) +* Write single register (0x06) +* Write multiple coils (0x0f) +* Write multiple registers (0x10) + +Go object types: +* Booleans (coils and discrete inputs) +* Bytes (input and holding registers) +* Signed/Unisgned 16-bit integers (input and holding registers) +* Signed/Unsigned 32-bit integers (input and holding registers) +* 32-bit floating point numbers (input and holding registers) +* Signed/Unsigned 64-bit integers (input and holding registers) +* 64-bit floating point numbers (input and holding registers) + +Byte encoding/endianness/word ordering: +* Little and Big endian for byte slices and 16-bit integers +* Little and Big endian, with and without word swap for 32 and 64-bit + integers and floating point numbers. + +### Logging ### +Both client and server objects will log to stdout by default. +This behavior can be overriden by passing a log.Logger object +through the Logger property of ClientConfiguration/ServerConfiguration. + +### TODO (in no particular order) +* Add RTU (serial) support to the server +* Add more tests +* Add diagnostics register support +* Add fifo register support +* Add file register support + +### Dependencies +* [github.com/goburrow/serial](https://github.com/goburrow/serial) for access to the serial port (thanks!) + +### License +MIT. diff --git a/client.go b/client.go new file mode 100644 index 0000000..ed1302f --- /dev/null +++ b/client.go @@ -0,0 +1,1320 @@ +package modbus + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "log" + "net" + "os" + "strings" + "sync" + "time" +) + +type RegType uint +type Endianness uint +type WordOrder uint + +const ( + PARITY_NONE uint = 0 + PARITY_EVEN uint = 1 + PARITY_ODD uint = 2 + + HOLDING_REGISTER RegType = 0 + INPUT_REGISTER RegType = 1 + + // endianness of 16-bit registers + BIG_ENDIAN Endianness = 1 + LITTLE_ENDIAN Endianness = 2 + + // word order of 32-bit registers + HIGH_WORD_FIRST WordOrder = 1 + LOW_WORD_FIRST WordOrder = 2 +) + +// Modbus client configuration object. +type ClientConfiguration struct { + // URL sets the client mode and target location in the form + // :// e.g. tcp://plc:502 + URL string + // Speed sets the serial link speed (in bps, rtu only) + Speed uint + // DataBits sets the number of bits per serial character (rtu only) + DataBits uint + // Parity sets the serial link parity mode (rtu only) + Parity uint + // StopBits sets the number of serial stop bits (rtu only) + StopBits uint + // Timeout sets the request timeout value + Timeout time.Duration + // TLSClientCert sets the client-side TLS key pair (tcp+tls only) + TLSClientCert *tls.Certificate + // TLSRootCAs sets the list of CA certificates used to authenticate + // the server (tcp+tls only). Leaf (i.e. server) certificates can also + // be used in case of self-signed certs, or if cert pinning is required. + TLSRootCAs *x509.CertPool + // Logger provides a custom sink for log messages. + // If nil, messages will be written to stdout. + Logger *log.Logger +} + +// Modbus client object. +type ModbusClient struct { + conf ClientConfiguration + logger *logger + lock sync.Mutex + endianness Endianness + wordOrder WordOrder + transport transport + unitId uint8 + transportType transportType +} + +// NewClient creates, configures and returns a modbus client object. +func NewClient(conf *ClientConfiguration) (mc *ModbusClient, err error) { + var clientType string + var splitURL []string + + mc = &ModbusClient{ + conf: *conf, + } + + splitURL = strings.SplitN(mc.conf.URL, "://", 2) + if len(splitURL) == 2 { + clientType = splitURL[0] + mc.conf.URL = splitURL[1] + } + + mc.logger = newLogger( + fmt.Sprintf("modbus-client(%s)", mc.conf.URL), conf.Logger) + + switch clientType { + case "rtu": + // set useful defaults + if mc.conf.Speed == 0 { + mc.conf.Speed = 19200 + } + + // note: the "modbus over serial line v1.02" document specifies an + // 11-bit character frame, with even parity and 1 stop bit as default, + // and mandates the use of 2 stop bits when no parity is used. + // This stack defaults to 8/N/2 as most devices seem to use no parity, + // but giving 8/N/1, 8/E/1 and 8/O/1 a shot may help with serial + // issues. + if mc.conf.DataBits == 0 { + mc.conf.DataBits = 8 + } + + if mc.conf.StopBits == 0 { + if mc.conf.Parity == PARITY_NONE { + mc.conf.StopBits = 2 + } else { + mc.conf.StopBits = 1 + } + } + + if mc.conf.Timeout == 0 { + mc.conf.Timeout = 300 * time.Millisecond + } + + mc.transportType = modbusRTU + + case "rtuovertcp": + if mc.conf.Speed == 0 { + mc.conf.Speed = 19200 + } + + if mc.conf.Timeout == 0 { + mc.conf.Timeout = 1 * time.Second + } + + mc.transportType = modbusRTUOverTCP + + case "rtuoverudp": + if mc.conf.Speed == 0 { + mc.conf.Speed = 19200 + } + + if mc.conf.Timeout == 0 { + mc.conf.Timeout = 1 * time.Second + } + + mc.transportType = modbusRTUOverUDP + + case "tcp": + if mc.conf.Timeout == 0 { + mc.conf.Timeout = 1 * time.Second + } + + mc.transportType = modbusTCP + + case "tcp+tls": + if mc.conf.Timeout == 0 { + mc.conf.Timeout = 1 * time.Second + } + + // expect a client-side certificate for mutual auth as the + // modbus/mpab protocol has no inherent auth facility. + // (see requirements R-08 and R-19 of the MBAPS spec) + if mc.conf.TLSClientCert == nil { + mc.logger.Errorf("missing client certificate") + err = ErrConfigurationError + return + } + + // expect a CertPool object containing at least 1 CA or + // leaf certificate to validate the server-side cert + if mc.conf.TLSRootCAs == nil { + mc.logger.Errorf("missing CA/server certificate") + err = ErrConfigurationError + return + } + + mc.transportType = modbusTCPOverTLS + + case "udp": + if mc.conf.Timeout == 0 { + mc.conf.Timeout = 1 * time.Second + } + + mc.transportType = modbusTCPOverUDP + + default: + if len(splitURL) != 2 { + mc.logger.Errorf("missing client type in URL '%s'", mc.conf.URL) + } else { + mc.logger.Errorf("unsupported client type '%s'", clientType) + } + err = ErrConfigurationError + return + } + + mc.unitId = 1 + mc.endianness = BIG_ENDIAN + mc.wordOrder = HIGH_WORD_FIRST + + return +} + +// Opens the underlying transport (network socket or serial line). +func (mc *ModbusClient) Open() (err error) { + var spw *serialPortWrapper + var sock net.Conn + + mc.lock.Lock() + defer mc.lock.Unlock() + + switch mc.transportType { + case modbusRTU: + // create a serial port wrapper object + spw = newSerialPortWrapper(&serialPortConfig{ + Device: mc.conf.URL, + Speed: mc.conf.Speed, + DataBits: mc.conf.DataBits, + Parity: mc.conf.Parity, + StopBits: mc.conf.StopBits, + }) + + // open the serial device + err = spw.Open() + if err != nil { + return + } + + // discard potentially stale serial data + discard(spw) + + // create the RTU transport + mc.transport = newRTUTransport( + spw, mc.conf.URL, mc.conf.Speed, mc.conf.Timeout, mc.conf.Logger) + + case modbusRTUOverTCP: + // connect to the remote host + sock, err = net.DialTimeout("tcp", mc.conf.URL, 5*time.Second) + if err != nil { + return + } + + // discard potentially stale serial data + discard(sock) + + // create the RTU transport + mc.transport = newRTUTransport( + sock, mc.conf.URL, mc.conf.Speed, mc.conf.Timeout, mc.conf.Logger) + + case modbusRTUOverUDP: + // open a socket to the remote host (note: no actual connection is + // being made as UDP is connection-less) + sock, err = net.DialTimeout("udp", mc.conf.URL, 5*time.Second) + if err != nil { + return + } + + // create the RTU transport, wrapping the UDP socket in + // an adapter to allow the transport to read the stream of + // packets byte per byte + mc.transport = newRTUTransport( + newUDPSockWrapper(sock), + mc.conf.URL, mc.conf.Speed, mc.conf.Timeout, mc.conf.Logger) + + case modbusTCP: + // connect to the remote host + sock, err = net.DialTimeout("tcp", mc.conf.URL, 5*time.Second) + if err != nil { + return + } + + // create the TCP transport + mc.transport = newTCPTransport(sock, mc.conf.Timeout, mc.conf.Logger) + + case modbusTCPOverTLS: + // connect to the remote host with TLS + sock, err = tls.DialWithDialer( + &net.Dialer{ + Deadline: time.Now().Add(15 * time.Second), + }, "tcp", mc.conf.URL, + &tls.Config{ + Certificates: []tls.Certificate{ + *mc.conf.TLSClientCert, + }, + RootCAs: mc.conf.TLSRootCAs, + // mandate TLS 1.2 or higher (see R-01 of the MBAPS spec) + MinVersion: tls.VersionTLS12, + }) + if err != nil { + return + } + + // force the TLS handshake + err = sock.(*tls.Conn).Handshake() + if err != nil { + sock.Close() + return + } + + // create the TCP transport, wrapping the TLS socket in + // an adapter to work around write timeouts corrupting internal + // state (see https://pkg.go.dev/crypto/tls#Conn.SetWriteDeadline) + mc.transport = newTCPTransport( + newTLSSockWrapper(sock), mc.conf.Timeout, mc.conf.Logger) + + case modbusTCPOverUDP: + // open a socket to the remote host (note: no actual connection is + // being made as UDP is connection-less) + sock, err = net.DialTimeout("udp", mc.conf.URL, 5*time.Second) + if err != nil { + return + } + + // create the TCP transport, wrapping the UDP socket in + // an adapter to allow the transport to read the stream of + // packets byte per byte + mc.transport = newTCPTransport( + newUDPSockWrapper(sock), mc.conf.Timeout, mc.conf.Logger) + + default: + // should never happen + err = ErrConfigurationError + } + + return +} + +// Closes the underlying transport. +func (mc *ModbusClient) Close() (err error) { + mc.lock.Lock() + defer mc.lock.Unlock() + + if mc.transport != nil { + err = mc.transport.Close() + } + + return +} + +// Sets the unit id of subsequent requests. +func (mc *ModbusClient) SetUnitId(id uint8) (err error) { + mc.lock.Lock() + defer mc.lock.Unlock() + + mc.unitId = id + + return +} + +// Sets the encoding (endianness and word ordering) of subsequent requests. +func (mc *ModbusClient) SetEncoding(endianness Endianness, wordOrder WordOrder) (err error) { + mc.lock.Lock() + defer mc.lock.Unlock() + + if endianness != BIG_ENDIAN && endianness != LITTLE_ENDIAN { + mc.logger.Errorf("unknown endianness value %v", endianness) + err = ErrUnexpectedParameters + return + } + + if wordOrder != HIGH_WORD_FIRST && wordOrder != LOW_WORD_FIRST { + mc.logger.Errorf("unknown word order value %v", wordOrder) + err = ErrUnexpectedParameters + return + } + + mc.endianness = endianness + mc.wordOrder = wordOrder + + return +} + +// Reads multiple coils (function code 01). +func (mc *ModbusClient) ReadCoils(addr uint16, quantity uint16) (values []bool, err error) { + values, err = mc.readBools(addr, quantity, false) + + return +} + +// Reads a single coil (function code 01). +func (mc *ModbusClient) ReadCoil(addr uint16) (value bool, err error) { + var values []bool + + values, err = mc.readBools(addr, 1, false) + if err == nil { + value = values[0] + } + + return +} + +// Reads multiple discrete inputs (function code 02). +func (mc *ModbusClient) ReadDiscreteInputs(addr uint16, quantity uint16) (values []bool, err error) { + values, err = mc.readBools(addr, quantity, true) + + return +} + +// Reads a single discrete input (function code 02). +func (mc *ModbusClient) ReadDiscreteInput(addr uint16) (value bool, err error) { + var values []bool + + values, err = mc.readBools(addr, 1, true) + if err == nil { + value = values[0] + } + + return +} + +// Reads multiple 16-bit registers (function code 03 or 04). +func (mc *ModbusClient) ReadRegisters(addr uint16, quantity uint16, regType RegType) (values []uint16, err error) { + var mbPayload []byte + + // read quantity uint16 registers, as bytes + mbPayload, err = mc.readRegisters(addr, quantity, regType) + if err != nil { + return + } + + // decode payload bytes as uint16s + values = bytesToUint16s(mc.endianness, mbPayload) + + return +} + +// Reads multiple 16-bit registers with function code +func (mc *ModbusClient) ReadRegistersWithFunctionCode(addr uint16, quantity uint16, funcCode uint8) (values []uint16, err error) { + var mbPayload []byte + + // read quantity uint16 registers, as bytes + mbPayload, err = mc.readRegistersWithFunctionCode(addr, quantity, funcCode) + if err != nil { + return + } + + // decode payload bytes as uint16s + values = bytesToUint16s(mc.endianness, mbPayload) + + return +} + +// Reads a single 16-bit register (function code 03 or 04). +func (mc *ModbusClient) ReadRegister(addr uint16, regType RegType) (value uint16, err error) { + var values []uint16 + + // read 1 uint16 register, as bytes + values, err = mc.ReadRegisters(addr, 1, regType) + if err == nil { + value = values[0] + } + + return +} + +// Reads multiple 32-bit registers. +func (mc *ModbusClient) ReadUint32s(addr uint16, quantity uint16, regType RegType) (values []uint32, err error) { + var mbPayload []byte + + // read 2 * quantity uint16 registers, as bytes + mbPayload, err = mc.readRegisters(addr, quantity*2, regType) + if err != nil { + return + } + + // decode payload bytes as uint32s + values = bytesToUint32s(mc.endianness, mc.wordOrder, mbPayload) + + return +} + +// Reads a single 32-bit register. +func (mc *ModbusClient) ReadUint32(addr uint16, regType RegType) (value uint32, err error) { + var values []uint32 + + values, err = mc.ReadUint32s(addr, 1, regType) + if err == nil { + value = values[0] + } + + return +} + +// Reads multiple 32-bit float registers. +func (mc *ModbusClient) ReadFloat32s(addr uint16, quantity uint16, regType RegType) (values []float32, err error) { + var mbPayload []byte + + // read 2 * quantity uint16 registers, as bytes + mbPayload, err = mc.readRegisters(addr, quantity*2, regType) + if err != nil { + return + } + + // decode payload bytes as float32s + values = bytesToFloat32s(mc.endianness, mc.wordOrder, mbPayload) + + return +} + +// Reads a single 32-bit float register. +func (mc *ModbusClient) ReadFloat32(addr uint16, regType RegType) (value float32, err error) { + var values []float32 + + values, err = mc.ReadFloat32s(addr, 1, regType) + if err == nil { + value = values[0] + } + + return +} + +// Reads multiple 64-bit registers. +func (mc *ModbusClient) ReadUint64s(addr uint16, quantity uint16, regType RegType) (values []uint64, err error) { + var mbPayload []byte + + // read 4 * quantity uint16 registers, as bytes + mbPayload, err = mc.readRegisters(addr, quantity*4, regType) + if err != nil { + return + } + + // decode payload bytes as uint64s + values = bytesToUint64s(mc.endianness, mc.wordOrder, mbPayload) + + return +} + +// Reads a single 64-bit register. +func (mc *ModbusClient) ReadUint64(addr uint16, regType RegType) (value uint64, err error) { + var values []uint64 + + values, err = mc.ReadUint64s(addr, 1, regType) + if err == nil { + value = values[0] + } + + return +} + +// Reads multiple 64-bit float registers. +func (mc *ModbusClient) ReadFloat64s(addr uint16, quantity uint16, regType RegType) (values []float64, err error) { + var mbPayload []byte + + // read 4 * quantity uint16 registers, as bytes + mbPayload, err = mc.readRegisters(addr, quantity*4, regType) + if err != nil { + return + } + + // decode payload bytes as float64s + values = bytesToFloat64s(mc.endianness, mc.wordOrder, mbPayload) + + return +} + +// Reads a single 64-bit float register. +func (mc *ModbusClient) ReadFloat64(addr uint16, regType RegType) (value float64, err error) { + var values []float64 + + values, err = mc.ReadFloat64s(addr, 1, regType) + if err == nil { + value = values[0] + } + + return +} + +// Reads one or multiple 16-bit registers (function code 03 or 04) as bytes. +// A per-register byteswap is performed if endianness is set to LITTLE_ENDIAN. +func (mc *ModbusClient) ReadBytes(addr uint16, quantity uint16, regType RegType) (values []byte, err error) { + values, err = mc.readBytes(addr, quantity, regType, true) + + return +} + +// Reads one or multiple 16-bit registers (function code 03 or 04) as bytes. +// No byte or word reordering is performed: bytes are returned exactly as they come +// off the wire, allowing the caller to handle encoding/endianness/word order manually. +func (mc *ModbusClient) ReadRawBytes(addr uint16, quantity uint16, regType RegType) (values []byte, err error) { + values, err = mc.readBytes(addr, quantity, regType, false) + + return +} + +// Writes a single coil (function code 05) +func (mc *ModbusClient) WriteCoil(addr uint16, value bool) (err error) { + var req *pdu + var res *pdu + + mc.lock.Lock() + defer mc.lock.Unlock() + + // create and fill in the request object + req = &pdu{ + unitId: mc.unitId, + functionCode: fcWriteSingleCoil, + } + + // coil address + req.payload = uint16ToBytes(BIG_ENDIAN, addr) + // coil value + if value { + req.payload = append(req.payload, 0xff, 0x00) + } else { + req.payload = append(req.payload, 0x00, 0x00) + } + + // run the request across the transport and wait for a response + res, err = mc.executeRequest(req) + if err != nil { + return + } + + // validate the response code + switch { + case res.functionCode == req.functionCode: + // expect 4 bytes (2 byte of address + 2 bytes of value) + if len(res.payload) != 4 || + // bytes 1-2 should be the coil address + bytesToUint16(BIG_ENDIAN, res.payload[0:2]) != addr || + // bytes 3-4 should either be {0xff, 0x00} or {0x00, 0x00} + // depending on the coil value + (value == true && res.payload[2] != 0xff) || + res.payload[3] != 0x00 { + err = ErrProtocolError + return + } + + case res.functionCode == (req.functionCode | 0x80): + if len(res.payload) != 1 { + err = ErrProtocolError + return + } + + err = mapExceptionCodeToError(res.payload[0]) + + default: + err = ErrProtocolError + mc.logger.Warningf("unexpected response code (%v)", res.functionCode) + } + + return +} + +// Writes multiple coils (function code 15) +func (mc *ModbusClient) WriteCoils(addr uint16, values []bool) (err error) { + var req *pdu + var res *pdu + var quantity uint16 + var encodedValues []byte + + mc.lock.Lock() + defer mc.lock.Unlock() + + quantity = uint16(len(values)) + if quantity == 0 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of coils is 0") + return + } + + if quantity > 0x7b0 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of coils exceeds 1968") + return + } + + if uint32(addr)+uint32(quantity)-1 > 0xffff { + err = ErrUnexpectedParameters + mc.logger.Error("end coil address is past 0xffff") + return + } + + encodedValues = encodeBools(values) + + // create and fill in the request object + req = &pdu{ + unitId: mc.unitId, + functionCode: fcWriteMultipleCoils, + } + + // start address + req.payload = uint16ToBytes(BIG_ENDIAN, addr) + // quantity + req.payload = append(req.payload, uint16ToBytes(BIG_ENDIAN, quantity)...) + // byte count + req.payload = append(req.payload, byte(len(encodedValues))) + // payload + req.payload = append(req.payload, encodedValues...) + + // run the request across the transport and wait for a response + res, err = mc.executeRequest(req) + if err != nil { + return + } + + // validate the response code + switch { + case res.functionCode == req.functionCode: + // expect 4 bytes (2 byte of address + 2 bytes of quantity) + if len(res.payload) != 4 || + // bytes 1-2 should be the base coil address + bytesToUint16(BIG_ENDIAN, res.payload[0:2]) != addr || + // bytes 3-4 should be the quantity of coils + bytesToUint16(BIG_ENDIAN, res.payload[2:4]) != quantity { + err = ErrProtocolError + return + } + + case res.functionCode == (req.functionCode | 0x80): + if len(res.payload) != 1 { + err = ErrProtocolError + return + } + + err = mapExceptionCodeToError(res.payload[0]) + + default: + err = ErrProtocolError + mc.logger.Warningf("unexpected response code (%v)", res.functionCode) + } + + return +} + +// Writes a single 16-bit register (function code 06). +func (mc *ModbusClient) WriteRegister(addr uint16, value uint16) (err error) { + var req *pdu + var res *pdu + + mc.lock.Lock() + defer mc.lock.Unlock() + + // create and fill in the request object + req = &pdu{ + unitId: mc.unitId, + functionCode: fcWriteSingleRegister, + } + + // register address + req.payload = uint16ToBytes(BIG_ENDIAN, addr) + // register value + req.payload = append(req.payload, uint16ToBytes(mc.endianness, value)...) + + // run the request across the transport and wait for a response + res, err = mc.executeRequest(req) + if err != nil { + return + } + + // validate the response code + switch { + case res.functionCode == req.functionCode: + // expect 4 bytes (2 byte of address + 2 bytes of value) + if len(res.payload) != 4 || + // bytes 1-2 should be the register address + bytesToUint16(BIG_ENDIAN, res.payload[0:2]) != addr || + // bytes 3-4 should be the value + bytesToUint16(mc.endianness, res.payload[2:4]) != value { + err = ErrProtocolError + return + } + + case res.functionCode == (req.functionCode | 0x80): + if len(res.payload) != 1 { + err = ErrProtocolError + return + } + + err = mapExceptionCodeToError(res.payload[0]) + + default: + err = ErrProtocolError + mc.logger.Warningf("unexpected response code (%v)", res.functionCode) + } + + return +} + +// Writes multiple 16-bit registers (function code 16). +func (mc *ModbusClient) WriteRegisters(addr uint16, values []uint16) (err error) { + var payload []byte + + // turn registers to bytes + for _, value := range values { + payload = append(payload, uint16ToBytes(mc.endianness, value)...) + } + + err = mc.writeRegisters(addr, payload) + + return +} + +// Writes multiple 32-bit registers. +func (mc *ModbusClient) WriteUint32s(addr uint16, values []uint32) (err error) { + var payload []byte + + // turn registers to bytes + for _, value := range values { + payload = append(payload, uint32ToBytes(mc.endianness, mc.wordOrder, value)...) + } + + err = mc.writeRegisters(addr, payload) + + return +} + +// Writes a single 32-bit register. +func (mc *ModbusClient) WriteUint32(addr uint16, value uint32) (err error) { + err = mc.writeRegisters(addr, uint32ToBytes(mc.endianness, mc.wordOrder, value)) + + return +} + +// Writes multiple 32-bit float registers. +func (mc *ModbusClient) WriteFloat32s(addr uint16, values []float32) (err error) { + var payload []byte + + // turn registers to bytes + for _, value := range values { + payload = append(payload, float32ToBytes(mc.endianness, mc.wordOrder, value)...) + } + + err = mc.writeRegisters(addr, payload) + + return +} + +// Writes a single 32-bit float register. +func (mc *ModbusClient) WriteFloat32(addr uint16, value float32) (err error) { + err = mc.writeRegisters(addr, float32ToBytes(mc.endianness, mc.wordOrder, value)) + + return +} + +// Writes multiple 64-bit registers. +func (mc *ModbusClient) WriteUint64s(addr uint16, values []uint64) (err error) { + var payload []byte + + // turn registers to bytes + for _, value := range values { + payload = append(payload, uint64ToBytes(mc.endianness, mc.wordOrder, value)...) + } + + err = mc.writeRegisters(addr, payload) + + return +} + +// Writes a single 64-bit register. +func (mc *ModbusClient) WriteUint64(addr uint16, value uint64) (err error) { + err = mc.writeRegisters(addr, uint64ToBytes(mc.endianness, mc.wordOrder, value)) + + return +} + +// Writes multiple 64-bit float registers. +func (mc *ModbusClient) WriteFloat64s(addr uint16, values []float64) (err error) { + var payload []byte + + // turn registers to bytes + for _, value := range values { + payload = append(payload, float64ToBytes(mc.endianness, mc.wordOrder, value)...) + } + + err = mc.writeRegisters(addr, payload) + + return +} + +// Writes a single 64-bit float register. +func (mc *ModbusClient) WriteFloat64(addr uint16, value float64) (err error) { + err = mc.writeRegisters(addr, float64ToBytes(mc.endianness, mc.wordOrder, value)) + + return +} + +// Writes the given slice of bytes to 16-bit registers starting at addr. +// A per-register byteswap is performed if endianness is set to LITTLE_ENDIAN. +// Odd byte quantities are padded with a null byte to fall on 16-bit register boundaries. +func (mc *ModbusClient) WriteBytes(addr uint16, values []byte) (err error) { + err = mc.writeBytes(addr, values, true) + + return +} + +// Writes the given slice of bytes to 16-bit registers starting at addr. +// No byte or word reordering is performed: bytes are pushed to the wire as-is, +// allowing the caller to handle encoding/endianness/word order manually. +// Odd byte quantities are padded with a null byte to fall on 16-bit register boundaries. +func (mc *ModbusClient) WriteRawBytes(addr uint16, values []byte) (err error) { + err = mc.writeBytes(addr, values, false) + + return +} + +/*** unexported methods ***/ +// Reads one or multiple 16-bit registers (function code 03 or 04) as bytes. +func (mc *ModbusClient) readBytes(addr uint16, quantity uint16, regType RegType, observeEndianness bool) (values []byte, err error) { + var regCount uint16 + + // read enough registers to get the requested number of bytes + // (2 bytes per reg) + regCount = (quantity / 2) + (quantity % 2) + + values, err = mc.readRegisters(addr, regCount, regType) + if err != nil { + return + } + + // swap bytes on register boundaries if requested by the caller + // and endianness is set to little endian + if observeEndianness && mc.endianness == LITTLE_ENDIAN { + for i := 0; i < len(values); i += 2 { + values[i], values[i+1] = values[i+1], values[i] + } + } + + // pop the last byte on odd quantities + if quantity%2 == 1 { + values = values[0 : len(values)-1] + } + + return +} + +// Writes the given slice of bytes to 16-bit registers starting at addr. +func (mc *ModbusClient) writeBytes(addr uint16, values []byte, observeEndianness bool) (err error) { + // pad odd quantities to make for full registers + if len(values)%2 == 1 { + values = append(values, 0x00) + } + + // swap bytes on register boundaries if requested by the caller + // and endianness is set to little endian + if observeEndianness && mc.endianness == LITTLE_ENDIAN { + for i := 0; i < len(values); i += 2 { + values[i], values[i+1] = values[i+1], values[i] + } + } + + err = mc.writeRegisters(addr, values) + + return +} + +// Reads and returns quantity booleans. +// Digital inputs are read if di is true, otherwise coils are read. +func (mc *ModbusClient) readBools(addr uint16, quantity uint16, di bool) (values []bool, err error) { + var req *pdu + var res *pdu + var expectedLen int + + mc.lock.Lock() + defer mc.lock.Unlock() + + if quantity == 0 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of coils/discrete inputs is 0") + return + } + + if quantity > 2000 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of coils/discrete inputs exceeds 2000") + return + } + + if uint32(addr)+uint32(quantity)-1 > 0xffff { + err = ErrUnexpectedParameters + mc.logger.Error("end coil/discrete input address is past 0xffff") + return + } + + // create and fill in the request object + req = &pdu{ + unitId: mc.unitId, + } + + if di { + req.functionCode = fcReadDiscreteInputs + } else { + req.functionCode = fcReadCoils + } + + // start address + req.payload = uint16ToBytes(BIG_ENDIAN, addr) + // quantity + req.payload = append(req.payload, uint16ToBytes(BIG_ENDIAN, quantity)...) + + // run the request across the transport and wait for a response + res, err = mc.executeRequest(req) + if err != nil { + return + } + + // validate the response code + switch { + case res.functionCode == req.functionCode: + // expect a payload of 1 byte (byte count) + 1 byte for 8 coils/discrete inputs) + expectedLen = 1 + expectedLen += int(quantity) / 8 + if quantity%8 != 0 { + expectedLen++ + } + + if len(res.payload) != expectedLen { + err = ErrProtocolError + return + } + + // validate the byte count field + if int(res.payload[0])+1 != expectedLen { + err = ErrProtocolError + return + } + + // turn bits into a bool slice + values = decodeBools(quantity, res.payload[1:]) + + case res.functionCode == (req.functionCode | 0x80): + if len(res.payload) != 1 { + err = ErrProtocolError + return + } + + err = mapExceptionCodeToError(res.payload[0]) + + default: + err = ErrProtocolError + mc.logger.Warningf("unexpected response code (%v)", res.functionCode) + } + + return +} + +// Reads and returns quantity registers of type regType, as bytes. +func (mc *ModbusClient) readRegisters(addr uint16, quantity uint16, regType RegType) (bytes []byte, err error) { + var req *pdu + var res *pdu + + mc.lock.Lock() + defer mc.lock.Unlock() + + // create and fill in the request object + req = &pdu{ + unitId: mc.unitId, + } + + switch regType { + case HOLDING_REGISTER: + req.functionCode = fcReadHoldingRegisters + case INPUT_REGISTER: + req.functionCode = fcReadInputRegisters + default: + err = ErrUnexpectedParameters + mc.logger.Errorf("unexpected register type (%v)", regType) + return + } + + if quantity == 0 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of registers is 0") + return + } + + if quantity > 125 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of registers exceeds 125") + return + } + + if uint32(addr)+uint32(quantity)-1 > 0xffff { + err = ErrUnexpectedParameters + mc.logger.Error("end register address is past 0xffff") + return + } + + // start address + req.payload = uint16ToBytes(BIG_ENDIAN, addr) + // quantity + req.payload = append(req.payload, uint16ToBytes(BIG_ENDIAN, quantity)...) + + // run the request across the transport and wait for a response + res, err = mc.executeRequest(req) + if err != nil { + return + } + + // validate the response code + switch { + case res.functionCode == req.functionCode: + // make sure the payload length is what we expect + // (1 byte of length + 2 bytes per register) + if len(res.payload) != 1+2*int(quantity) { + err = ErrProtocolError + return + } + + // validate the byte count field + // (2 bytes per register * number of registers) + if uint(res.payload[0]) != 2*uint(quantity) { + err = ErrProtocolError + return + } + + // remove the byte count field from the returned slice + bytes = res.payload[1:] + + case res.functionCode == (req.functionCode | 0x80): + if len(res.payload) != 1 { + err = ErrProtocolError + return + } + + err = mapExceptionCodeToError(res.payload[0]) + + default: + err = ErrProtocolError + mc.logger.Warningf("unexpected response code (%v)", res.functionCode) + } + + return +} + +// Reads and returns quantity registers of type regType, as bytes. +func (mc *ModbusClient) readRegistersWithFunctionCode(addr uint16, quantity uint16, functionCode uint8) (bytes []byte, err error) { + var req *pdu + var res *pdu + + mc.lock.Lock() + defer mc.lock.Unlock() + + // create and fill in the request object + req = &pdu{ + unitId: mc.unitId, + } + + if functionCode != fcCustomize { + err = ErrUnexpectedParameters + mc.logger.Errorf("unexpected function code (%d)", functionCode) + return + } + + if functionCode == 0 { + err = ErrUnexpectedParameters + mc.logger.Errorf("unexpected register type (%v)", functionCode) + return + } + + req.functionCode = functionCode + + if quantity == 0 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of registers is 0") + return + } + + if quantity > 1024 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of registers exceeds 1024") + return + } + + if uint32(addr)+uint32(quantity)-1 > 0xffff { + err = ErrUnexpectedParameters + mc.logger.Error("end register address is past 0xffff") + return + } + + // start address + req.payload = uint16ToBytes(BIG_ENDIAN, addr) + // quantity + req.payload = append(req.payload, uint16ToBytes(BIG_ENDIAN, quantity)...) + + // run the request across the transport and wait for a response + res, err = mc.executeRequest(req) + if err != nil { + return + } + + // validate the response code + switch { + case res.functionCode == req.functionCode: + // make sure the payload length is what we expect + // (1 byte of length + 2 bytes per register) + if len(res.payload) != 1+2*int(quantity) { + err = ErrProtocolError + return + } + + // validate the byte count field + // (2 bytes per register * number of registers) + if uint(res.payload[0]) != 2*uint(quantity) { + err = ErrProtocolError + return + } + + // remove the byte count field from the returned slice + bytes = res.payload[1:] + + case res.functionCode == (req.functionCode | 0x80): + if len(res.payload) != 1 { + err = ErrProtocolError + return + } + + err = mapExceptionCodeToError(res.payload[0]) + + default: + err = ErrProtocolError + mc.logger.Warningf("unexpected response code (%v)", res.functionCode) + } + + return +} + +// Writes multiple registers starting from base address addr. +// Register values are passed as bytes, each value being exactly 2 bytes. +func (mc *ModbusClient) writeRegisters(addr uint16, values []byte) (err error) { + var req *pdu + var res *pdu + var payloadLength uint16 + var quantity uint16 + + mc.lock.Lock() + defer mc.lock.Unlock() + + payloadLength = uint16(len(values)) + quantity = payloadLength / 2 + + if quantity == 0 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of registers is 0") + return + } + + if quantity > 123 { + err = ErrUnexpectedParameters + mc.logger.Error("quantity of registers exceeds 123") + return + } + + if uint32(addr)+uint32(quantity)-1 > 0xffff { + err = ErrUnexpectedParameters + mc.logger.Error("end register address is past 0xffff") + return + } + + // create and fill in the request object + req = &pdu{ + unitId: mc.unitId, + functionCode: fcWriteMultipleRegisters, + } + + // base address + req.payload = uint16ToBytes(BIG_ENDIAN, addr) + // quantity of registers (2 bytes per register) + req.payload = append(req.payload, uint16ToBytes(BIG_ENDIAN, quantity)...) + // byte count + req.payload = append(req.payload, byte(payloadLength)) + // registers value + req.payload = append(req.payload, values...) + + // run the request across the transport and wait for a response + res, err = mc.executeRequest(req) + if err != nil { + return + } + + // validate the response code + switch { + case res.functionCode == req.functionCode: + // expect 4 bytes (2 byte of address + 2 bytes of quantity) + if len(res.payload) != 4 || + // bytes 1-2 should be the base register address + bytesToUint16(BIG_ENDIAN, res.payload[0:2]) != addr || + // bytes 3-4 should be the quantity of registers (2 bytes per register) + bytesToUint16(BIG_ENDIAN, res.payload[2:4]) != quantity { + err = ErrProtocolError + return + } + + case res.functionCode == (req.functionCode | 0x80): + if len(res.payload) != 1 { + err = ErrProtocolError + return + } + + err = mapExceptionCodeToError(res.payload[0]) + + default: + err = ErrProtocolError + mc.logger.Warningf("unexpected response code (%v)", res.functionCode) + } + + return +} + +func (mc *ModbusClient) executeRequest(req *pdu) (res *pdu, err error) { + // send the request over the wire, wait for and decode the response + res, err = mc.transport.ExecuteRequest(req) + if err != nil { + // map i/o timeouts to ErrRequestTimedOut + if os.IsTimeout(err) { + err = ErrRequestTimedOut + } + return + } + + // make sure the source unit id matches that of the request + if (res.functionCode&0x80) == 0x00 && res.unitId != req.unitId { + err = ErrBadUnitId + return + } + // accept errors from gateway devices (using special unit id #255) + if (res.functionCode&0x80) == 0x80 && + (res.unitId != req.unitId && res.unitId != 0xff) { + err = ErrBadUnitId + return + } + + return +} diff --git a/client_tls_test.go b/client_tls_test.go new file mode 100644 index 0000000..ed178f5 --- /dev/null +++ b/client_tls_test.go @@ -0,0 +1,584 @@ +package modbus + +import ( + "crypto/tls" + "crypto/x509" + "io" + "fmt" + "net" + "testing" + "time" +) + +const ( + // note: these certs and associated keys are self-signed + // and only meant to be used with this test. + // PLEASE DO NOT USE THEM FOR ANYTHING ELSE, EVER, as they + // do not provide any kind of security. + serverCert string = `-----BEGIN CERTIFICATE----- +MIIFkzCCA3ugAwIBAgIUWnvnN1r9czWyX7TGS+AwTNICd4wwDQYJKoZIhvcNAQEL +BQAwKTEnMCUGA1UEAwwebG9jYWxob3N0IFRFU1QgQ0VSVCBETyBOT1QgVVNFMB4X +DTIwMDgyMTA5NTEyMVoXDTQwMDgxNjA5NTEyMVowKTEnMCUGA1UEAwwebG9jYWxo +b3N0IFRFU1QgQ0VSVCBETyBOT1QgVVNFMIICIjANBgkqhkiG9w0BAQEFAAOCAg8A +MIICCgKCAgEA4D+a4wqvwxhyMhN4Z6EG7pIU1TfL7hV2MH4Izx1sDHGaUu+318SE +Egn85Zn1PbYAvYqlN+Ti3pCH5/tSJJHD4XVGcXtp3Wswt5MTXX8Ny3f1v3ZeQggp +nTy2tyODTulCQBg5L+8FTgJM2mJR0D+dryswiWgDVBLxg5W9p7icff30n/LHtEGd +jTkVbkaG798iGaIIeI6YS1wjMfsPWGWpG9SVoC3bkHN2NL2apecCLsoZpb+DiKdT +1rBG2pNeDseGpSWKwF/2/HeJsw+tD4okbtfYA7uURmRyqv1rxAXmclZXHFHpUL8l +Vt69g+ER0sXmLavM2Jj3iss6RF2MP6ghVUAcaciPbuDCn6+vnxCE6L2Gyr9G6Aur +rOBl/nRj3BHK9agp0fLzhIKgfCKMzCU5mo/UFlJIKbKIRJcdF5LNF3A9wD0K3Rv/ +2bIvaXdWwIgUQ+zX3V3cDuMatCs/F2jGE5FejGaNeA7ixfpdCtybBpzGewLB49NB +AIFBboJBdfW3QuqQBM32GFmbwM4cZpdxr97cZTDJh3Age7e8BSWPO195IJEKWNSn +bnWCDNG5J4G6MBf9AfC/ljJCrOIEN4wTXP6EF4vaMq/VWz674j7QvR9q9aKB1lNn +bdKd/LMH+jmgG8bGuy01Tj12/JBgzgG0KI72364wuJlTjneqkTpCncMCAwEAAaOB +sjCBrzAdBgNVHQ4EFgQUEFEWxaWofRhSTd2+ZWmaH14OKAswHwYDVR0jBBgwFoAU +EFEWxaWofRhSTd2+ZWmaH14OKAswDwYDVR0TAQH/BAUwAwEB/zAsBgNVHREEJTAj +gglsb2NhbGhvc3SHEAAAAAAAAAAAAAAAAAAAAAGHBH8AAAEwCwYDVR0PBAQDAgWg +MBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAwDQYJKoZIhvcNAQEL +BQADggIBACYfWN0rK/OTXDZEQHLWLe6SDailqJSULVRhzxevo3i07MhsQCARofXH +A37b0q/GHhLsKUEsjB+TYip90eKaRV1XDgnMOdbQxVCKY5IVTmobnoj5ZLja1AgD +4I3CTxx3tqxLxRmV1Bre1tIqHnalA6kC6HJAKW+R0biPSccchIjW2ljB1SXBcS2T +RjAsIfbEl9sDhyLl8jaOaOBLPLS5PNgRs4XJ8/ps9dDCNyeOizGVzAsgaeTxXadC +Y505cxoWQR1atYerjqVr0XolCTkefOapsNJzH3YXF2mxKJCVxARv8Ns8e5WwIWw2 +r1ESi6M1qca5CutEgbBdUp7NyF44HJ9O3EsG+CFO6XRn6aaUvmin6vufKk29usRm +L3RWqBH1vz3vQzVLfEzXnJwnxDwZWcBrGx3RjKAL+O+hWHc3Qh6+AfI0yRX4j0MR +7IMHESf2xkCtw58w1t+OA1GBZ7hBX4zRiAQ89hk8UzRMw45yQ3cPkAp9u+PhrY1i +9dcDqvPueaSDoRMl7VvHyQ+2SeQF7mc3Xx6iAm9HPBmuVWVpX32g9jbu0xfzWhng +DXf3U5zg6BsG3gR5omPwbApKBlGckRY+ZuarhxPeczBx6KVIOKgvafybKrCsbso2 +oq2sBRSZveoEKZDOmZpsUP2jYrcgrybnurcoN6g1Chl28V5rNITd +-----END CERTIFICATE-----` + + serverKey string = `-----BEGIN PRIVATE KEY----- +MIIJRAIBADANBgkqhkiG9w0BAQEFAASCCS4wggkqAgEAAoICAQDgP5rjCq/DGHIy +E3hnoQbukhTVN8vuFXYwfgjPHWwMcZpS77fXxIQSCfzlmfU9tgC9iqU35OLekIfn ++1IkkcPhdUZxe2ndazC3kxNdfw3Ld/W/dl5CCCmdPLa3I4NO6UJAGDkv7wVOAkza +YlHQP52vKzCJaANUEvGDlb2nuJx9/fSf8se0QZ2NORVuRobv3yIZogh4jphLXCMx ++w9YZakb1JWgLduQc3Y0vZql5wIuyhmlv4OIp1PWsEbak14Ox4alJYrAX/b8d4mz +D60PiiRu19gDu5RGZHKq/WvEBeZyVlccUelQvyVW3r2D4RHSxeYtq8zYmPeKyzpE +XYw/qCFVQBxpyI9u4MKfr6+fEITovYbKv0boC6us4GX+dGPcEcr1qCnR8vOEgqB8 +IozMJTmaj9QWUkgpsohElx0Xks0XcD3APQrdG//Zsi9pd1bAiBRD7NfdXdwO4xq0 +Kz8XaMYTkV6MZo14DuLF+l0K3JsGnMZ7AsHj00EAgUFugkF19bdC6pAEzfYYWZvA +zhxml3Gv3txlMMmHcCB7t7wFJY87X3kgkQpY1KdudYIM0bkngbowF/0B8L+WMkKs +4gQ3jBNc/oQXi9oyr9VbPrviPtC9H2r1ooHWU2dt0p38swf6OaAbxsa7LTVOPXb8 +kGDOAbQojvbfrjC4mVOOd6qROkKdwwIDAQABAoICAQDD5IxHPbSgdyB6wityS2ak +zZPJVr6csr7WSaMkWo1iqXKodKRiplbA81yqrb1gNTecXBtMInRU/Gjcq9zr+THm +J+5rf+XQ+KxMEPzftfe1AIv6v0pD4KGJq9npTeqM6pNnLkH2r5Qwuy2rsCvMAWab ++Nyji+ssbIfx7MMKWujJ3yjs+MafnpolHfKsrIt/y6ocPkGsHtTHMCvGo4yaKeR6 +XVB/5s9g9pwSIneP6acsfHu/IPekTpececzLb+TAgGgMqCj3OF2n2jy94TnK02BU +O9WGHTy/6UuKN2sGiCjxRJ9ALAXm9bOGmXlwVRKezyXuS5/crnPAGRxDUH0Ntq+2 +B9Cpwd2YA2UO3aw2w1fcVhdi+CYBNNSfnWdksRNfUH02g0EwITz28Onm69pJv3ze +6y4Vm9ZVksJmC6HJ0OzwMmqvDnK8aqN0jSUlhUeJOmVkWyJL5JFH0L2hHyadWOrX +EU9HORiznkMzcubcaexFnyBvwlmeordR2V94aQpkAE1zJT5YHH4YStE7qGStU+8S +kOikBytsY+SGe68OYUBdZyVpCx43b0c3XiXYkazxRN6GtMsTJh+1R8pg6DkIarj2 +HVZZotQS0ldkJkYSOpvUkAdy6mV3KfKvYhi0QGRFjMwD5OFhH2vX7kbgOtkKCCSb +fjSCsz2kEQyuNb4BIsLLkQKCAQEA/WibivWpORzrI+rLjQha3J7IfzaeWQN4l2G5 +Y/qrAWdYpuiZM3fkVHoo6Zg7uZaGxY47JAxWNAMNl/k2oqh7GKKNy2cK6xSvA/sP +MWgzQlvTqj6gewIDW7APiJVnmEtwOkkEsBGdty5t+68VNITXHO2HgwbJWgMd+Ou0 +2/bmkpPVEqKqIOqbgfDEKJkUK5HvM4wFK5fFYv/iIz5RhTFhlUBVO3RQtPjs735v +2dd+KXND+YZZrxCTv1wBFaZ3T27JWEq4JZhk7W0Y6JiYavN2quDHqfztaXDXmdv3 +FO0XnjSJ8U4rehNuuWX4+hx9JmAzN2wqKQAfaYamHnuR4Ob53wKCAQEA4oqo5C4h +xAc/d4Q8e3h6P5SgVTgNaGbz0oFimfv+OO0qJ2GQKomV1WAbqMatnwERoCnlZy84 +BSt3RYGY5arH7zU81LR8xKS7w4teBwU6x8CVGpn+UL/3ARCcueFyEohtt0RawOcr +IaXdrYSwjHnQr5qjxDrYGG5z+2/ynZzcKWvWAI789MJ9T/cnfsdBiKkW34KdLMnb +hlAfYPibs7CJdH9R2yXIYzobXihbkY4i7czCe3uoIoxkmmDFGJSo1WMZgFaoSlr/ +ltgFPyuvD9r0JHGynhMXXiCmWg/l5mZW6Lfuzb9LF7Znus3rbHFQcvLauSg9cxZT +hlmEMz7U/ZCgnQKCAQEAwNNx0GqgiyobL2iB3V5nLYvRiyO3mIpQn/inxpE+wMGw +Lsm9kfGAGFwgd6f0goMtKHTTQdn1WnycQnFLhrhnetZuyUEuiLVje8b1x6W/o5YW +WWxwV0mv3nv5RfhSLQvyaReY7pVpCrPU0vhmTWFsAsIoJKbsXocSrpBFPkABMbY2 +I4kNpiB/ln/r8+yP8ZuJhhLc+E/zziJiJGlOROjPlW+vq58Vrq/gM1llqUEV6lqg +deYqplEZ7DoJRT03eoUVxw6MU2dEHXqvwoYjLPb37I1AwXQJ//ryxEwiFpVXLHZU +JP9Ti//veDpFG6TEAoifUGQJLMvAG19vVrC2z4lSxwKCAQBjv/xX5Lw3bZ2TiaV8 +FHN3tYDXpUO6GcL4iMIa3Wt2M2+hQYNSR5yzBIuJSFpArh7NsETzp0X6eMYe0864 +Kfe5K27qlcJub77BfodbfgEA3ZqJyQ7DDZO8Y00vR8aLxIjS7oUrdV53hWpTsh5u +7GBoQiYkDGkEcPYe248vuVbz4iirvEpDl7PH1yML3r7LZvDMX93HT+aagIMglrcw +auZLZphrb3qJvpc4YXrYX4afwM5NwwgoljriAwQmK6cftnAPI5kcjG8IQ3wj8Z82 +0wk3Vtz4X52lc6jr9R4c0ikodXzwGW/+M/H+vhcQe+CZjLekWcSc/VKv0JC2Y88z +C1C9AoIBAQCKqMG7SsuH0E6qqq/vhfTHLLZVjnTBXigJKamZEwOiKq6Ib6xOPei6 +A9FugwAc10xdDS7AUy0EsPUUWBzFhLpjQO+CWPxcxA+ia35pKbfFjdy5DtOns736 +6Q1l8HT2JQw1siYGB+P3zyffpAuzYZ/ieaAoivwvuU0TRSjPEbljk8NCQBK0BNas +8pLBIe6ht7vcFsBiZyHTtBNSWZPkLz4HRGBGaaxPHernWsV4HtZlI64SsAa9n7Kz +2F7OMs1XatPrO+zwtx3xDB6iQYqCfzOfTNrq0fSwythyUQ29frvOLmJXBf2D2Wkj +yAqUh6zMzzcee67KOWWZMTuPQuu1n/m1 +-----END PRIVATE KEY-----` + + clientCert string = ` +-----BEGIN CERTIFICATE----- +MIIFXzCCA0egAwIBAgIUQzQeLPGsr6OmD5NtrsiYMFVwATwwDQYJKoZIhvcNAQEL +BQAwJjEkMCIGA1UEAwwbVEVTVCBDTElFTlQgQ0VSVCBETyBOT1QgVVNFMB4XDTIw +MDgyMTA5NTI0NVoXDTQwMDgxNjA5NTI0NVowJjEkMCIGA1UEAwwbVEVTVCBDTElF +TlQgQ0VSVCBETyBOT1QgVVNFMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKC +AgEA5yBIvheS8d2T8lBn9BOsdi1mMmhUHyqxdx9YFgwIV0NYb9s3/J83Jf/Focob +DM4fdy7iuSECX8/KUoymQmn26ivzmI4iLJ0LsBbUhTzMO9lo82Vg18E4Ab4GQVMz +LWcxnt1wt2EYJ5nq72c1h9K27pIecDDZ9DtBD1j0d3cuoo8HIVzsUfZRp80H9b5H +WuY2nMPZC3jp6HlsVSHCbkuscs/d7dDGK4tsanmyqVfBNmJhNKvo2GwMM9vf82li +dh6OwrqUNeMXkDU8vQ/xMGWfZ+Xpu0vXx3pQ5SX3WVDZFMuk7/mMBZwUhnXh+yjC +R1MVIRJvjijnkFzSSXctoysl/Mrc3QW5QmPmoa8KVWL8pSc5oaMMdX9bXo9omPf1 +XlmjKMvEUu2IbUQcaDFtVAKzR0UGcYEFh/QCIu7WV0pA24DfiX3r0Lv1GHHB6X+3 ++zUH7ZcaajzVQM/crCB/VLQiZODg6EgFtx1woll5hES/I6l9Me7UKySlxY4wTjIk +k/cwIN6R0dHGDny30fIkOl0vseo6SKtIkApYfe2tAASatbNdpkJQbMjUkm1IlcYj +ZZdn7yssjLjmAWn6pz19GbW2sAGQ/D4k/rN2hOpCc3NSP2FMIeWwp6TUHUG2ZFck +79j/Fo2bPNTPiFWxUW6h6gWHDgZg8UFz3FXvaeUM0URFOjsCAwEAAaOBhDCBgTAd +BgNVHQ4EFgQUZkTFhQZ4vaia2hiIQrZnSfx2nOAwHwYDVR0jBBgwFoAUZkTFhQZ4 +vaia2hiIQrZnSfx2nOAwDwYDVR0TAQH/BAUwAwEB/zALBgNVHQ8EBAMCBaAwEwYD +VR0lBAwwCgYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADANBgkqhkiG9w0BAQsFAAOC +AgEAhDVloT4TqZL66A/N8GSbiAALYVM4VoQlaiYiNrtwcNBAKrveN/RJlVxvYC9f +ybsvHz+wf/UAkWsQqd7pacKMlctNYoavpooO2ketQErpz+Ysmb+yGlL5u9DwQI36 +bsw/sxZjA4uunEM3yySVZf2k5j97HzBhhv24dlgxCPyu5tOAj83bM2QLTc5H7/KZ +ZEhMcrXN0+QUI9np3WYPKAPMJNODSMGD8mMqpjRufxDH0jhPhX4R4qvhHT+/OrLE +CwLTwtgZ8BnRS2b16QEGpvT7bu5EWZda4vgXQEeuMpEgUmwPOm2JS9QZguXrhA6u +Jd/12gbNEowQCt0qig1K2/ouYc3YKvCq/GuDPZnVq0nXEgSom4+g4UpU92zHARSy +CjfEW+rD9ay0ipzl6wxV09ZoQOoFwztf/AO89gl2CDtcw1J+mB8KcP2Pme+lWZ9m +mj7+ed+lubE5kBIK/H2EojEUceGmdluqD/T6bUaAR6edLuS0z4MKFTNlbbZq9QiS +vb6vr137SqCw56gFvYzxxOS2037QHAHk9dZz4+ik6BLXOQmHY1s59y/iAV3CrWwf +wVi6BS05QtOQW1nzeUU4DyMz4aAuBs88iGqDlipzkMreyYTG/66WpKCp/nezSn5H +cufNpBGKcE0Ww/H/GgMvKe/nB7HEJQqoAxVDeq75WFiHQrs= +-----END CERTIFICATE----- +` + clientKey string = `-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQDnIEi+F5Lx3ZPy +UGf0E6x2LWYyaFQfKrF3H1gWDAhXQ1hv2zf8nzcl/8WhyhsMzh93LuK5IQJfz8pS +jKZCafbqK/OYjiIsnQuwFtSFPMw72WjzZWDXwTgBvgZBUzMtZzGe3XC3YRgnmerv +ZzWH0rbukh5wMNn0O0EPWPR3dy6ijwchXOxR9lGnzQf1vkda5jacw9kLeOnoeWxV +IcJuS6xyz93t0MYri2xqebKpV8E2YmE0q+jYbAwz29/zaWJ2Ho7CupQ14xeQNTy9 +D/EwZZ9n5em7S9fHelDlJfdZUNkUy6Tv+YwFnBSGdeH7KMJHUxUhEm+OKOeQXNJJ +dy2jKyX8ytzdBblCY+ahrwpVYvylJzmhowx1f1tej2iY9/VeWaMoy8RS7YhtRBxo +MW1UArNHRQZxgQWH9AIi7tZXSkDbgN+JfevQu/UYccHpf7f7NQftlxpqPNVAz9ys +IH9UtCJk4ODoSAW3HXCiWXmERL8jqX0x7tQrJKXFjjBOMiST9zAg3pHR0cYOfLfR +8iQ6XS+x6jpIq0iQClh97a0ABJq1s12mQlBsyNSSbUiVxiNll2fvKyyMuOYBafqn +PX0ZtbawAZD8PiT+s3aE6kJzc1I/YUwh5bCnpNQdQbZkVyTv2P8WjZs81M+IVbFR +bqHqBYcOBmDxQXPcVe9p5QzRREU6OwIDAQABAoICABiXYcYAAh2D4uLsVTMuCLKG +QBJq8VBjnYA8MIYf/58xRi6Yl4tkcVy0qxV8yIYDRGvM7EigT31cQX2pA2ObnK7r +wD5iGRbAGudAdpo6jsxrZHRJPBWYtFnTGx1GOfLBwRDTJNQOG6DTCqEwTQzHibk2 +iNCNEhOfXlvArjorzyVyrGKLXYWW/Lcq5IbsGPF9/x+M4wIKenDGwpUIQ4SyvoV0 +wns0NHGbowxtKGpGMQOVUhxlkh+810uJQHnIo7ZHqA7mBTD6mZ45W94N3S62EVDf +sI/CERJjXEoVUQ0Kwh4pUMJLve823SQ1VLcBbjJij6P2LzJj/cdpaOJyMMPkqmTY +cRUxtM/n5TQ9DVI867BKvDz/TplGaYKFEW1pmMZ2fH2w+YT/gZY8+YkVQsPVuj6c +sedxoAF6fUP4t/ROZkibyMyTJ4v2wF+tjugXddOM5DN9C4BuYJJgZ8tDWy+nf5Oe +weik6cheBXLYJd/LjZop1s+2pGe5/EjDiI16jdoVCwdPKTjNNjeopZr/Od0/C8jj +mljOYyf0wqrQsEBrmOxCtL0QL7gDg6kjYEWR2zbYiAoQu/iDZgdSb0V1RqAagiJP +qLMILSPDi8KHAh+k8z7JjSBXvhffSMtP+zI4iKfibTVAiLw5tUGzDuUrehgysSK7 +n9ETpNwwQZuQLx26KJrZAoIBAQD8IlVicwTDjKndJ4an1TOrUCU9qe6OKTO5RQoS +Jma+qOdrcWBIttbWZ8K9CcJUcdhImnp1Es+GTnMqCP2UG9UMysJAZeM3d5uwlpe3 +8s6Ju9x2n3yywCzSKgO0mragkdOUEyf+dHF6LgC098EsrfiDchvKinHgSpcFqmSB +1lC9QgyeikvmlSbwJcVWQXWN7dnP4dXL1j6ej6WB9ITKa7cztiBNWmAGAhJE8GUG +tmLG/zk7MV9cCzzJ7a2k9a63C6EGR3b4svs+SFEx8IzRHvzBHH+qdCAdJhzuD7Yp +jASbHNZ0b1yBEJNYTSIUrygFi0+6Ol5AMQusJGGHo0UErrhlAoIBAQDqq32o2sJa +0BsOEyXhAgZgZP+WocMmsduWqfvTCoze+MfD/f6vTBJZ45A98vd3xUHw/CsyS/dM +vjTRQCa0QCflgm1LiwbNR50QhXZ63AJ1ob2+FvDZFrPXJJcykUdMY1UwQh3gHKk/ +VYsm/s8L+VDDZXQlrcDQyHxVvML9Pn1GASgtp0NOKy6YA+jR5VJEqhMo8tlGdvfO +vXF24QIwoCYgKe+62jbJ9OB0XLVG8QrA4jV8oKI+U32kUwcDQC5EOw29rq0lbOA5 +ig4ry5SmkrlKK6wveOZYjWo3JKKo/o/cJnmZtEYnHzQf6p6m0M9odNgULgm3zO+I +3nXjNNTIg+4fAoIBABI7XVdAH/EQA9x1FjyeoxzZL8g0uIZZHl9gSakkU7untQxE +54R6jDB20lMfGIlIri4Z1Y8PrCf3FkbM3aFPHenN45wKghKpuH1ddl0b1qmJBxkg +0UCPuu37kccGhPw5b0Y+2F6DBw2hs/ViEPrtHZJLtwy/VBq26hLDzn7BA5eb5hO0 +xmZHFMi6wnlJRHnd4CkzGGWj+WU31+z8xHlqrpWzrsRJK7ZjgfSwOW3x1FS1cesA +1/ds7JlhcXQDO/4KfjtZAZZcQuSvEAf/b/9TMU25hNXLjeLttZvVUQPSFycsP6mt +v8+pZi41bah3Pfqgp0Q9IkGcCk8JVnAbc0syYy0CggEATeNNedXh3DJmSG2ijOQX +Kbdb/asDErzFnWQd6RX/W6JG645KEfS1wo/9OBKEgIRANrP7wl3kXtxiu3EHZ5xD +obGAhSpHv6qdPvaNNIoBZvmf+I+0sNkQJ8BFTstZVslBZRsMv23D3vmNjgvUvKyr +Wa86tabN8H4ahnp4XYV4HtwTcdOqSy+Z72qcw83RWGj6owS3iOPDrCLEnihgibMd +9F726pWyyaU1Omnq4PjwEMUD67GFKBqeAQRtt2597LeNAAASB/HzGiXwPij71a2t +QijspXUDPzDwqAzI0D5tkSxT/+gNwL5ilpVQwx1bOdhOP6RoJVEnz83GYvsOBN+F +EQKCAQEAo9j9MG+VCz+loz4fUXIJjC63ypfuRfxTCAIBMn4HzohP8chEcQBlWLCH +t0WcguYnwsuxGR4Rhx02UZCx3qNxiroBZ9w1NqTk947ZjKuNzqI7IpIqvtJ18op6 +QgQu8piNkf0/etAO0e6IjbZe4WfJCeKsAqE4vCV43baaSiHN/0pfYi6LLJ2YmTF/ ++sYY43naHg3zQTL4JbL4c58ebe4ADj4wIdNJ+/H5JgQf6r14iNjpyc6BJOjFuPyx +EJHQKb6499HKFua3QuH/kA6Ogfm9o3Lnwx/VO1lPLFteTv1fBKK00C00SkmyIe1p +iaKCVjivzjP1s/q6adzOOZVlVwm7Xw== +-----END PRIVATE KEY----- +` +) + +// TestTCPOVerTLSClient tests the TLS layer of the modbus client. +func TestTCPoverTLSClient(t *testing.T) { + var err error + var client *ModbusClient + var serverKeyPair tls.Certificate + var clientKeyPair tls.Certificate + var clientCp *x509.CertPool + var serverCp *x509.CertPool + var serverHostPort string + var serverChan chan string + var regs []uint16 + + serverChan = make(chan string) + + // load server and client keypairs + serverKeyPair, err = tls.X509KeyPair([]byte(serverCert), []byte(serverKey)) + if err != nil { + t.Errorf("failed to load test server key pair: %v", err) + return + } + + clientKeyPair, err = tls.X509KeyPair([]byte(clientCert), []byte(clientKey)) + if err != nil { + t.Errorf("failed to load test client key pair: %v", err) + return + } + + // start with an empty client cert pool initially to reject the server + // certificate + clientCp = x509.NewCertPool() + + // start with an empty server cert pool initially to reject the client + // certificate + serverCp = x509.NewCertPool() + + // start a mock modbus TLS server + go runMockTLSServer(t, serverKeyPair, serverCp, serverChan) + + // wait for the test server goroutine to signal its readiness + // and network location + serverHostPort = <-serverChan + + // attempt to create a client without specifying any TLS configuration + // parameter: should fail + client, err = NewClient(&ClientConfiguration{ + URL: fmt.Sprintf("tcp+tls://%s", serverHostPort), + }) + if err != ErrConfigurationError { + t.Errorf("NewClient() should have failed with %v, got: %v", + ErrConfigurationError, err) + } + + // attempt to create a client without specifying any TLS server + // cert/CA: should fail + client, err = NewClient(&ClientConfiguration{ + URL: fmt.Sprintf("tcp+tls://%s", serverHostPort), + TLSClientCert: &clientKeyPair, + }) + if err != ErrConfigurationError { + t.Errorf("NewClient() should have failed with %v, got: %v", + ErrConfigurationError, err) + } + + // attempt to create a client with both client cert+key and server + // cert/CA: should succeed + client, err = NewClient(&ClientConfiguration{ + URL: fmt.Sprintf("tcp+tls://%s", serverHostPort), + TLSClientCert: &clientKeyPair, + TLSRootCAs: clientCp, + }) + if err != nil { + t.Errorf("NewClient() should have succeeded, got: %v", err) + } + + // connect to the server: should fail with a TLS error as the server cert + // is not yet trusted by the client + err = client.Open() + if err == nil { + t.Errorf("Open() should have failed") + } + + // now load the server certificate into the client's trusted cert pool + // to get the client to accept the server's certificate + if !clientCp.AppendCertsFromPEM([]byte(serverCert)) { + t.Errorf("failed to load test server cert into cert pool") + } + + // connect to the server: should succeed + // note: client certificates are verified after the handshake procedure + // has completed, so Open() won't fail even though the client cert + // is rejected by the server. + // (see RFC 8446 section 4.6.2 Post Handshake Authentication) + err = client.Open() + if err != nil { + t.Errorf("Open() should have succeeded, got: %v", err) + } + + // attempt to read two registers: since the client cert won't pass + // the validation step yet (no cert in server cert pool), + // expect a tls error + regs, err = client.ReadRegisters(0x1000, 2, INPUT_REGISTER) + if err == nil { + t.Errorf("ReadRegisters() should have failed") + } + client.Close() + + // now place the client cert in the server's authorized client list + // to get the client cert past the validation procedure + if !serverCp.AppendCertsFromPEM([]byte(clientCert)) { + t.Errorf("failed to load test client cert into cert pool") + } + + // connect to the server: should succeed + err = client.Open() + if err != nil { + t.Errorf("Open() should have succeeded, got: %v", err) + } + + // attempt to read two registers: should succeed + regs, err = client.ReadRegisters(0x1000, 2, INPUT_REGISTER) + if err != nil { + t.Errorf("ReadRegisters() should have succeeded, got: %v", err) + } + if regs[0] != 0x1234 { + t.Errorf("expected 0x1234 in 1st reg, saw: 0x%04x", regs[0]) + } + if regs[1] != 0x5678 { + t.Errorf("expected 0x5678 in 2nd reg, saw: 0x%04x", regs[1]) + } + + // attempt to read another: should succeed + regs, err = client.ReadRegisters(0x1002, 1, HOLDING_REGISTER) + if err != nil { + t.Errorf("ReadRegisters() should have succeeded, got: %v", err) + } + if regs[0] != 0xaabb { + t.Errorf("expected 0xaabb in 1st reg, saw: 0x%04x", regs[0]) + } + + // close the connection: should succeed + err = client.Close() + if err != nil { + t.Errorf("Close() should have succeeded, got: %v", err) + } + + return +} + +func TestTLSClientOnServerTimeout(t *testing.T) { + var err error + var client *ModbusClient + var server *ModbusServer + var serverKeyPair tls.Certificate + var clientKeyPair tls.Certificate + var clientCp *x509.CertPool + var serverCp *x509.CertPool + var th *tlsTestHandler + var reg uint16 + + th = &tlsTestHandler{} + // load server and client keypairs + serverKeyPair, err = tls.X509KeyPair([]byte(serverCert), []byte(serverKey)) + if err != nil { + t.Errorf("failed to load test server key pair: %v", err) + return + } + + clientKeyPair, err = tls.X509KeyPair([]byte(clientCert), []byte(clientKey)) + if err != nil { + t.Errorf("failed to load test client key pair: %v", err) + return + } + + // add those keypairs to their corresponding cert pool + clientCp = x509.NewCertPool() + if !clientCp.AppendCertsFromPEM([]byte(serverCert)) { + t.Errorf("failed to load test server cert into cert pool") + } + + serverCp = x509.NewCertPool() + if !serverCp.AppendCertsFromPEM([]byte(clientCert)) { + t.Errorf("failed to load client cert into cert pool") + } + + + // load the server cert into the client CA cert pool to get the server cert + // accepted by clients + clientCp = x509.NewCertPool() + if !clientCp.AppendCertsFromPEM([]byte(serverCert)) { + t.Errorf("failed to load test server cert into cert pool") + } + + server, err = NewServer(&ServerConfiguration{ + URL: "tcp+tls://[::1]:5802", + MaxClients: 10, + TLSServerCert: &serverKeyPair, + TLSClientCAs: serverCp, + // disconnect idle clients after 500ms + Timeout: 500 * time.Millisecond, + }, th) + if err != nil { + t.Errorf("failed to create server: %v", err) + } + + err = server.Start() + if err != nil { + t.Errorf("failed to start server: %v", err) + } + + // create the modbus client + client, err = NewClient(&ClientConfiguration{ + URL: "tcp+tls://localhost:5802", + TLSClientCert: &clientKeyPair, + TLSRootCAs: clientCp, + }) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + + // connect to the server: should succeed + err = client.Open() + if err != nil { + t.Errorf("Open() should have succeeded, got: %v", err) + } + + // write a value to register #3: should succeed + err = client.WriteRegister(3, 0x0199) + if err != nil { + t.Errorf("Write() should have succeeded, got: %v", err) + } + + // attempt to read the value back: should succeed + reg, err = client.ReadRegister(3, HOLDING_REGISTER) + if err != nil { + t.Errorf("ReadRegisters() should have succeeded, got: %v", err) + } + if reg != 0x0199 { + t.Errorf("expected 0x0199 in reg #3, saw: 0x%04x", reg) + } + + // pause for longer than the server's configured timeout to end up with + // an open client with a closed underlying TCP socket + time.Sleep(1 * time.Second) + + // attempt a read: should fail + _, err = client.ReadRegister(3, INPUT_REGISTER) + if err == nil { + t.Errorf("ReadRegister() should have failed") + } + + // cleanup + client.Close() + server.Stop() + + return +} + +// runMockTLSServer spins a test TLS server for use with TestTCPoverTLSClient. +func runMockTLSServer(t *testing.T, serverKeyPair tls.Certificate, + serverCp *x509.CertPool, serverChan chan string) { + var err error + var listener net.Listener + var sock net.Conn + var reqCount uint + var clientCount uint + var buf []byte + + // let the OS pick an available port on the loopback interface + listener, err = tls.Listen("tcp", "localhost:0", &tls.Config{ + // the server will use serverKeyPair (key+cert) to + // authenticate to the client + Certificates: []tls.Certificate{serverKeyPair}, + // the server will use the certpool to authenticate the + // client-side cert + ClientCAs: serverCp, + // request client-side authentication and client cert validation + ClientAuth: tls.RequireAndVerifyClientCert, + }) + if err != nil { + t.Errorf("failed to start test server listener: %v", err) + } + defer listener.Close() + + // let the main test goroutine know which port the OS picked + serverChan <- listener.Addr().String() + + for err == nil { + // accept client connections + sock, err = listener.Accept() + if err != nil { + t.Errorf("failed to accept client conn: %v", err) + break + } + + // only proceed with clients passing the tls handshake + // note: this will reject any client whose cert does not pass the + // verification step + err = sock.(*tls.Conn).Handshake() + if err != nil { + sock.Close() + err = nil + continue + } + + clientCount++ + if clientCount > 2 { + t.Errorf("expected 2 client conns, saw: %v", clientCount) + } + + // expect MBAP (modbus/tcp) messages inside the TLS tunnel + for { + // expect 12 bytes per request + buf = make([]byte, 12) + + _, err = sock.Read(buf) + if err != nil { + // ignore EOF errors (clients disconnecting) + if err != io.EOF { + t.Errorf("failed to read client request: %v", err) + } + sock.Close() + break + } + + reqCount++ + switch reqCount { + case 1: + for i, b := range []byte{ + 0x00, 0x01, // txn id + 0x00, 0x00, // protocol id + 0x00, 0x06, // length + 0x01, 0x04, // unit id + function code + 0x10, 0x00, // start address + 0x00, 0x02, // quantity + } { + if b != buf[i] { + t.Errorf("expected 0x%02x at pos %v, saw 0x%02x", + b, i, buf[i]) + } + } + + // send a reply + _, err = sock.Write([]byte{ + 0x00, 0x01, // txn id + 0x00, 0x00, // protocol id + 0x00, 0x07, // length + 0x01, 0x04, // unit id + function code + 0x04, // byte count + 0x12, 0x34, // reg #0 + 0x56, 0x78, // reg #1 + }) + if err != nil { + t.Errorf("failed to write reply: %v", err) + } + + case 2: + for i, b := range []byte{ + 0x00, 0x02, // txn id + 0x00, 0x00, // protocol id + 0x00, 0x06, // length + 0x01, 0x03, // unit id + function code + 0x10, 0x02, // start address + 0x00, 0x01, // quantity + } { + if b != buf[i] { + t.Errorf("expected 0x%02x at pos %v, saw 0x%02x", + b, i, buf[i]) + } + } + + // send a reply + _, err = sock.Write([]byte{ + 0x00, 0x02, // txn id + 0x00, 0x00, // protocol id + 0x00, 0x05, // length + 0x01, 0x03, // unit id + function code + 0x02, // byte count + 0xaa, 0xbb, // reg #0 + }) + if err != nil { + t.Errorf("failed to write reply: %v", err) + } + + // stop the server after the 2nd request + listener.Close() + + default: + t.Errorf("unexpected request id %v", reqCount) + return + } + } + } +} diff --git a/cmd/modbus-cli.go b/cmd/modbus-cli.go new file mode 100644 index 0000000..dcff32f --- /dev/null +++ b/cmd/modbus-cli.go @@ -0,0 +1,1288 @@ +package main + +import ( + "crypto/tls" + "encoding/hex" + "errors" + "fmt" + "flag" + "os" + "strings" + "strconv" + "time" + + "github.com/simonvetter/modbus" +) + +func main() { + var err error + var help bool + var client *modbus.ModbusClient + var config *modbus.ClientConfiguration + var target string + var caPath string // path to TLS CA/server certificate + var certPath string // path to TLS client certificate + var keyPath string // path to TLS client key + var clientKeyPair tls.Certificate + var speed uint + var dataBits uint + var parity string + var stopBits uint + var endianness string + var wordOrder string + var timeout string + var cEndianess modbus.Endianness + var cWordOrder modbus.WordOrder + var unitId uint + var runList []operation + + flag.StringVar(&target, "target", "", "target device to connect to (e.g. tcp://somehost:502) [required]") + flag.UintVar(&speed, "speed", 19200, "serial bus speed in bps (rtu)") + flag.UintVar(&dataBits, "data-bits", 8, "number of bits per character on the serial bus (rtu)") + flag.StringVar(&parity, "parity", "none", "parity bit on the serial bus (rtu)") + flag.UintVar(&stopBits, "stop-bits", 2, "number of stop bits <0|1|2>) on the serial bus (rtu)") + flag.StringVar(&timeout, "timeout", "3s", "timeout value") + flag.StringVar(&endianness, "endianness", "big", "register endianness ") + flag.StringVar(&wordOrder, "word-order", "highfirst", "word ordering for 32-bit registers ") + flag.UintVar(&unitId, "unit-id", 1, "unit/slave id to use") + flag.StringVar(&certPath, "cert", "", "path to TLS client certificate") + flag.StringVar(&keyPath, "key", "", "path to TLS client key") + flag.StringVar(&caPath, "ca", "", "path to TLS CA/server certificate") + flag.BoolVar(&help, "help", false, "show a wall-of-text help message") + flag.Parse() + + if help { + displayHelp() + os.Exit(0) + } + + if target == "" { + fmt.Printf("no target specified, please use --target\n") + os.Exit(1) + } + + // create and populate the client configuration object + config = &modbus.ClientConfiguration{ + URL: target, + Speed: speed, + DataBits: dataBits, + StopBits: stopBits, + } + + switch parity { + case "none": config.Parity = modbus.PARITY_NONE + case "odd": config.Parity = modbus.PARITY_ODD + case "even": config.Parity = modbus.PARITY_EVEN + default: + fmt.Printf("unknown parity setting '%s' (should be one of none, odd or even)\n", + parity) + os.Exit(1) + } + + config.Timeout, err = time.ParseDuration(timeout) + if err != nil { + fmt.Printf("failed to parse timeout setting '%s': %v\n", timeout, err) + os.Exit(1) + } + + // parse encoding (endianness and word order) settings + switch endianness { + case "big": cEndianess = modbus.BIG_ENDIAN + case "little": cEndianess = modbus.LITTLE_ENDIAN + default: + fmt.Printf("unknown endianness setting '%s' (should either be big or little)\n", + endianness) + os.Exit(1) + } + + switch wordOrder { + case "highfirst", "hf": cWordOrder = modbus.HIGH_WORD_FIRST + case "lowfirst", "lf": cWordOrder = modbus.LOW_WORD_FIRST + default: + fmt.Printf("unknown word order setting '%s' (should be one of highfirst, hf, littlefirst, lf)\n", + wordOrder) + os.Exit(1) + } + + // handle TLS options + if strings.HasPrefix(target, "tcp+tls://") { + if certPath == "" { + fmt.Print("TLS requested but no client certificate given, please use --cert\n") + os.Exit(1) + } + + if keyPath == "" { + fmt.Print("TLS requested but no client key given, please use --key\n") + os.Exit(1) + } + + if caPath == "" { + fmt.Print("TLS requested but no CA/server cert given, please use --ca\n") + os.Exit(1) + } + + clientKeyPair, err = tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + fmt.Printf("failed to load client tls key pair: %v\n", err) + os.Exit(1) + } + config.TLSClientCert = &clientKeyPair + + config.TLSRootCAs, err = modbus.LoadCertPool(caPath) + if err != nil { + fmt.Printf("failed to load tls CA/server certificate: %v\n", err) + os.Exit(1) + } + } + + if len(flag.Args()) == 0 { + fmt.Printf("nothing to do.\n") + os.Exit(0) + } + + // parse arguments and build a list of objects + for _, arg := range flag.Args() { + var splitArgs []string + var o operation + + splitArgs = strings.Split(arg, ":") + if len(splitArgs) < 2 && splitArgs[0] != "repeat" && splitArgs[0] != "date" { + fmt.Printf("illegal command format (should be command:arg1:arg2..., e.g. rh:uint32:0x1000+5)\n") + os.Exit(2) + } + + switch splitArgs[0] { + case "rc", "readCoil", "readCoils", + "rdi", "readDiscreteInput", "readDiscreteInputs": + + if len(splitArgs) != 2 { + fmt.Printf("need exactly 1 argument after rc/rdi, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + if splitArgs[0] == "rc" || splitArgs[0] == "readCoil" || splitArgs[0] == "readCoils" { + o.isCoil = true + } + + o.op = readBools + o.addr, o.quantity, err = parseAddressAndQuantity(splitArgs[1]) + if err != nil { + fmt.Printf("failed to parse address ('%v'): %v\n", splitArgs[1], err) + os.Exit(2) + } + + case "rh", "readHoldingRegister", "readHoldingRegisters", + "ri", "readInputRegister", "readInputRegisters": + + if len(splitArgs) != 3 { + fmt.Printf("need exactly 2 arguments after rh/ri, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + if splitArgs[0] == "rh" || splitArgs[0] == "readHoldingRegister" || + splitArgs[0] == "readHoldingRegisters" { + o.isHoldingReg = true + } + + switch splitArgs[1] { + case "uint16": o.op = readUint16 + case "int16": o.op = readInt16 + case "uint32": o.op = readUint32 + case "int32": o.op = readInt32 + case "float32": o.op = readFloat32 + case "uint64": o.op = readUint64 + case "int64": o.op = readInt64 + case "float64": o.op = readFloat64 + case "bytes": o.op = readBytes + default: + fmt.Printf("unknown register type '%v' (should be one of " + + "[u]unt16, [u]int32, [u]int64, float32, float64, bytes)\n", + splitArgs[1]) + os.Exit(2) + } + + o.addr, o.quantity, err = parseAddressAndQuantity(splitArgs[2]) + if err != nil { + fmt.Printf("failed to parse address ('%v'): %v\n", splitArgs[2], err) + os.Exit(2) + } + + case "wc", "writeCoil": + if len(splitArgs) != 3 { + fmt.Printf("need exactly 2 arguments after writeCoil, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + o.op = writeCoil + o.addr, err = parseUint16(splitArgs[1]) + if err != nil { + fmt.Printf("failed to parse address ('%v'): %v\n", splitArgs[1], err) + os.Exit(2) + } + + switch splitArgs[2] { + case "true": o.coil = true + case "false": o.coil = false + default: + fmt.Printf("failed to parse coil value '%v' (should either be true or false)\n", + splitArgs[2]) + os.Exit(2) + } + + case "wr", "writeRegister": + if len(splitArgs) != 4 { + fmt.Printf("need exactly 3 arguments after writeRegister, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + o.addr, err = parseUint16(splitArgs[2]) + if err != nil { + fmt.Printf("failed to parse address ('%v'): %v\n", splitArgs[2], err) + os.Exit(2) + } + + switch splitArgs[1] { + case "uint16": + o.op = writeUint16 + o.u16, err = parseUint16(splitArgs[3]) + + case "int16": + o.op = writeInt16 + o.u16, err = parseInt16(splitArgs[3]) + + case "uint32": + o.op = writeUint32 + o.u32, err = parseUint32(splitArgs[3]) + + case "int32": + o.op = writeInt32 + o.u32, err = parseInt32(splitArgs[3]) + + case "float32": + o.op = writeFloat32 + o.f32, err = parseFloat32(splitArgs[3]) + + case "uint64": + o.op = writeUint64 + o.u64, err = parseUint64(splitArgs[3]) + + case "int64": + o.op = writeInt64 + o.u64, err = parseInt64(splitArgs[3]) + + case "float64": + o.op = writeFloat64 + o.f64, err = parseFloat64(splitArgs[3]) + + case "bytes": + o.op = writeBytes + o.bytes, err = parseHexBytes(splitArgs[3]) + + case "string": + o.op = writeBytes + o.bytes = []byte(splitArgs[3]) + err = nil + + default: + fmt.Printf("unknown register type '%v' (should be one of " + + "[u]unt16, [u]int32, [u]int64, float32, float64, bytes, string)\n", + splitArgs[1]) + os.Exit(2) + } + + if err != nil { + fmt.Printf("failed to parse '%s' as %s: %v\n", splitArgs[3], splitArgs[1], err) + os.Exit(2) + } + + case "sleep": + if len(splitArgs) != 2 { + fmt.Printf("need exactly 1 argument after sleep, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + o.op = sleep + o.duration, err = time.ParseDuration(splitArgs[1]) + if err != nil { + fmt.Printf("failed to parse '%s' as duration: %v\n", splitArgs[1], err) + os.Exit(2) + } + + case "suid", "setUnitId", "sid": + if len(splitArgs) != 2 { + fmt.Printf("need exactly 1 argument after setUnitId, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + o.op = setUnitId + o.unitId, err = parseUnitId(splitArgs[1]) + if err != nil { + fmt.Printf("failed to parse '%s' as unit id: %v\n", splitArgs[1], err) + os.Exit(2) + } + + case "repeat": + if len(splitArgs) != 1 { + fmt.Printf("repeat takes no arguments, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + o.op = repeat + + case "date": + if len(splitArgs) != 1 { + fmt.Printf("date takes no arguments, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + o.op = date + + case "scan": + if len(splitArgs) != 2 { + fmt.Printf("need exactly 1 argument after scan, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + switch splitArgs[1] { + case "c", "coils": + o.op = scanBools + o.isCoil = true + case "di", "discreteInputs": + o.op = scanBools + o.isCoil = false + case "h", "hr", "holding", "holdingRegisters": + o.op = scanRegisters + o.isHoldingReg = true + case "i", "ir", "input", "inputRegisters": + o.op = scanRegisters + o.isHoldingReg = false + case "s", "sid": + o.op = scanUnitId + default: + fmt.Printf("unknown scan/register type '%s' (valid options \n", + splitArgs[1]) + os.Exit(2) + } + + case "ping": + if len(splitArgs) < 2 || len(splitArgs) > 3 { + fmt.Printf("need 1 or 2 arguments after ping, got %v\n", + len(splitArgs) - 1) + os.Exit(2) + } + + o.op = ping + o.quantity, err = parseUint16(splitArgs[1]) + if err != nil { + fmt.Printf("failed to parse ping count ('%v'): %v\n", splitArgs[1], err) + os.Exit(2) + } + + if o.quantity == 0 { + fmt.Printf("illegal ping count value (must be >= 1)\n") + os.Exit(2) + } + + if len(splitArgs) == 3 { + o.duration, err = time.ParseDuration(splitArgs[2]) + if err != nil { + fmt.Printf("failed to parse '%s' as duration: %v\n", splitArgs[2], err) + os.Exit(2) + } + } + + default: + fmt.Printf("unsupported command '%v'\n", splitArgs[0]) + os.Exit(2) + } + + runList = append(runList, o) + + } + + // create the modbus client + client, err = modbus.NewClient(config) + if err != nil { + fmt.Printf("failed to create client: %v\n", err) + os.Exit(1) + } + + err = client.SetEncoding(cEndianess, cWordOrder) + if err != nil { + fmt.Printf("failed to set encoding: %v\n", err) + os.Exit(1) + } + + // set the initial unit id (note: this can be changed later at runtime through + // the setUnitId command) + if unitId > 0xff { + fmt.Printf("set unit id: value '%v' out of range\n", unitId) + os.Exit(1) + } + client.SetUnitId(uint8(unitId)) + + // connect to the remote host/open the serial port + err = client.Open() + if err != nil { + fmt.Printf("failed to open client: %v\n", err) + os.Exit(2) + } + + for opIdx := 0; opIdx < len(runList); opIdx++ { + var o *operation = &runList[opIdx] + + switch o.op { + case readBools: + var res []bool + + if o.isCoil { + res, err = client.ReadCoils(o.addr, o.quantity + 1) + } else { + res, err = client.ReadDiscreteInputs(o.addr, o.quantity + 1) + } + if err != nil { + fmt.Printf("failed to read coils/discrete inputs: %v\n", err) + } else { + for idx := range res { + fmt.Printf("0x%04x\t%-5v : %v\n", o.addr + uint16(idx), + o.addr + uint16(idx), + res[idx]) + } + } + + case readUint16, readInt16: + var res []uint16 + + if o.isHoldingReg { + res, err = client.ReadRegisters(o.addr, o.quantity + 1, modbus.HOLDING_REGISTER) + } else { + res, err = client.ReadRegisters(o.addr, o.quantity + 1, modbus.INPUT_REGISTER) + } + if err != nil { + fmt.Printf("failed to read holding/input registers: %v\n", err) + } else { + for idx := range res { + if o.op == readUint16 { + fmt.Printf("0x%04x\t%-5v : 0x%04x\t%v\n", + o.addr + uint16(idx), + o.addr + uint16(idx), + res[idx], res[idx]) + } else { + fmt.Printf("0x%04x\t%-5v : 0x%04x\t%v\n", + o.addr + uint16(idx), + o.addr + uint16(idx), + res[idx], int16(res[idx])) + } + } + } + + case readUint32, readInt32: + var res []uint32 + + if o.isHoldingReg { + res, err = client.ReadUint32s(o.addr, o.quantity + 1, modbus.HOLDING_REGISTER) + } else { + res, err = client.ReadUint32s(o.addr, o.quantity + 1, modbus.INPUT_REGISTER) + } + if err != nil { + fmt.Printf("failed to read holding/input registers: %v\n", err) + } else { + for idx := range res { + if o.op == readUint32 { + fmt.Printf("0x%04x\t%-5v : 0x%08x\t%v\n", + o.addr + (uint16(idx) * 2), + o.addr + (uint16(idx) * 2), + res[idx], res[idx]) + } else { + fmt.Printf("0x%04x\t%-5v : 0x%08x\t%v\n", + o.addr + (uint16(idx) * 2), + o.addr + (uint16(idx) * 2), + res[idx], int32(res[idx])) + } + } + } + + case readFloat32: + var res []float32 + + if o.isHoldingReg { + res, err = client.ReadFloat32s(o.addr, o.quantity + 1, modbus.HOLDING_REGISTER) + } else { + res, err = client.ReadFloat32s(o.addr, o.quantity + 1, modbus.INPUT_REGISTER) + } + if err != nil { + fmt.Printf("failed to read holding/input registers: %v\n", err) + } else { + for idx := range res { + fmt.Printf("0x%04x\t%-5v : %f\n", + o.addr + (uint16(idx) * 2), + o.addr + (uint16(idx) * 2), + res[idx]) + } + } + + case readUint64, readInt64: + var res []uint64 + + if o.isHoldingReg { + res, err = client.ReadUint64s(o.addr, o.quantity + 1, modbus.HOLDING_REGISTER) + } else { + res, err = client.ReadUint64s(o.addr, o.quantity + 1, modbus.INPUT_REGISTER) + } + if err != nil { + fmt.Printf("failed to read holding/input registers: %v\n", err) + } else { + for idx := range res { + if o.op == readUint64 { + fmt.Printf("0x%04x\t%-5v : 0x%016x\t%v\n", + o.addr + (uint16(idx) * 4), + o.addr + (uint16(idx) * 4), + res[idx], res[idx]) + } else { + fmt.Printf("0x%04x\t%-5v : 0x%016x\t%v\n", + o.addr + (uint16(idx) * 4), + o.addr + (uint16(idx) * 4), + res[idx], int64(res[idx])) + } + } + } + + case readFloat64: + var res []float64 + + if o.isHoldingReg { + res, err = client.ReadFloat64s(o.addr, o.quantity + 1, modbus.HOLDING_REGISTER) + } else { + res, err = client.ReadFloat64s(o.addr, o.quantity + 1, modbus.INPUT_REGISTER) + } + if err != nil { + fmt.Printf("failed to read holding/input registers: %v\n", err) + } else { + for idx := range res { + fmt.Printf("0x%04x\t%-5v : %f\n", + o.addr + (uint16(idx) * 4), + o.addr + (uint16(idx) * 4), + res[idx]) + } + } + + case readBytes: + var res []byte + + if o.isHoldingReg { + res, err = client.ReadBytes(o.addr, o.quantity + 1, modbus.HOLDING_REGISTER) + } else { + res, err = client.ReadBytes(o.addr, o.quantity + 1, modbus.INPUT_REGISTER) + } + if err != nil { + fmt.Printf("failed to read holding/input registers: %v\n", err) + } else { + for idx := range res { + if (idx % 16) == 0 { + fmt.Printf("0x%04x\t%-5v : ", + o.addr + (uint16(idx/2)), o.addr + (uint16(idx/2))) + } + fmt.Printf("%02x", res[idx]) + + if (idx % 16) == 15 || idx == len(res) - 1 { + fmt.Printf(" <%s>\n", + decodeString(res[(idx/16*16):(idx/16*16)+(idx%16)+1])) + } else if (idx % 16) == 7 { + fmt.Printf(" ") + } + } + } + + case writeCoil: + err = client.WriteCoil(o.addr, o.coil) + if err != nil { + fmt.Printf("failed to write %v at coil address 0x%04x: %v\n", + o.coil, o.addr, err) + } else { + fmt.Printf("wrote %v at coil address 0x%04x\n", + o.coil, o.addr) + } + + case writeUint16: + err = client.WriteRegister(o.addr, o.u16) + if err != nil { + fmt.Printf("failed to write %v at register address 0x%04x: %v\n", + o.u16, o.addr, err) + } else { + fmt.Printf("wrote %v at register address 0x%04x\n", + o.u16, o.addr) + } + + case writeInt16: + err = client.WriteRegister(o.addr, o.u16) + if err != nil { + fmt.Printf("failed to write %v at register address 0x%04x: %v\n", + int16(o.u16), o.addr, err) + } else { + fmt.Printf("wrote %v at register address 0x%04x\n", + int16(o.u16), o.addr) + } + + case writeUint32: + err = client.WriteUint32(o.addr, o.u32) + if err != nil { + fmt.Printf("failed to write %v at address 0x%04x: %v\n", + o.u32, o.addr, err) + } else { + fmt.Printf("wrote %v at address 0x%04x\n", + o.u32, o.addr) + } + + case writeInt32: + err = client.WriteUint32(o.addr, o.u32) + if err != nil { + fmt.Printf("failed to write %v at address 0x%04x: %v\n", + int32(o.u32), o.addr, err) + } else { + fmt.Printf("wrote %v at address 0x%04x\n", + int32(o.u32), o.addr) + } + + case writeFloat32: + err = client.WriteFloat32(o.addr, o.f32) + if err != nil { + fmt.Printf("failed to write %f at address 0x%04x: %v\n", + o.f32, o.addr, err) + } else { + fmt.Printf("wrote %f at address 0x%04x\n", + o.f32, o.addr) + } + + case writeUint64: + err = client.WriteUint64(o.addr, o.u64) + if err != nil { + fmt.Printf("failed to write %v at address 0x%04x: %v\n", + o.u64, o.addr, err) + } else { + fmt.Printf("wrote %v at address 0x%04x\n", + o.u64, o.addr) + } + + case writeInt64: + err = client.WriteUint64(o.addr, o.u64) + if err != nil { + fmt.Printf("failed to write %v at address 0x%04x: %v\n", + int64(o.u64), o.addr, err) + } else { + fmt.Printf("wrote %v at address 0x%04x\n", + int64(o.u64), o.addr) + } + + case writeFloat64: + err = client.WriteFloat64(o.addr, o.f64) + if err != nil { + fmt.Printf("failed to write %f at address 0x%04x: %v\n", + o.f64, o.addr, err) + } else { + fmt.Printf("wrote %f at address 0x%04x\n", + o.f64, o.addr) + } + + case writeBytes: + err = client.WriteBytes(o.addr, o.bytes) + if err != nil { + fmt.Printf("failed to write %v at address 0x%04x: %v\n", + o.bytes, o.addr, err) + } else { + fmt.Printf("wrote %v bytes at address 0x%04x\n", + len(o.bytes), o.addr) + } + + case sleep: + time.Sleep(o.duration) + + case setUnitId: + client.SetUnitId(o.unitId) + + case repeat: + // start over + opIdx = -1 + + case date: + fmt.Printf("%s\n", time.Now().Format(time.RFC3339)) + + case scanBools: + performBoolScan(client, o.isCoil) + + case scanRegisters: + performRegisterScan(client, o.isHoldingReg) + + case scanUnitId: + performUnitIdScan(client) + + case ping: + performPing(client, o.quantity, o.duration); + + default: + fmt.Printf("unknown operation %v\n", o) + os.Exit(100) + } + } + + return +} + +const ( + readBools uint = iota + 1 + readUint16 + readInt16 + readUint32 + readInt32 + readFloat32 + readUint64 + readInt64 + readFloat64 + readBytes + writeCoil + writeCoils + writeUint16 + writeInt16 + writeInt32 + writeUint32 + writeFloat32 + writeInt64 + writeUint64 + writeFloat64 + writeBytes + setUnitId + sleep + repeat + date + scanBools + scanRegisters + scanUnitId + ping +) + +type operation struct { + op uint + addr uint16 + isCoil bool + isHoldingReg bool + quantity uint16 + coil bool + u16 uint16 + u32 uint32 + f32 float32 + u64 uint64 + f64 float64 + bytes []byte + duration time.Duration + unitId uint8 +} + +func parseUint16(in string) (u16 uint16, err error) { + var val uint64 + + val, err = strconv.ParseUint(in, 0, 16) + if err == nil { + u16 = uint16(val) + return + } + + return +} + +func parseInt16(in string) (u16 uint16, err error) { + var val int64 + + val, err = strconv.ParseInt(in, 0, 16) + if err == nil { + u16 = uint16(int16(val)) + } + + return +} + +func parseUint32(in string) (u32 uint32, err error) { + var val uint64 + + val, err = strconv.ParseUint(in, 0, 32) + if err == nil { + u32 = uint32(val) + return + } + + return +} + +func parseInt32(in string) (u32 uint32, err error) { + var val int64 + + val, err = strconv.ParseInt(in, 0, 32) + if err == nil { + u32 = uint32(int32(val)) + } + + return +} + +func parseFloat32(in string) (f32 float32, err error) { + var val float64 + + val, err = strconv.ParseFloat(in, 32) + if err == nil { + f32 = float32(val) + } + + return +} + +func parseUint64(in string) (u64 uint64, err error) { + var val uint64 + + val, err = strconv.ParseUint(in, 0, 64) + if err == nil { + u64 = val + } + + return +} + +func parseInt64(in string) (u64 uint64, err error) { + var val int64 + + val, err = strconv.ParseInt(in, 0, 64) + if err == nil { + u64 = uint64(val) + } + + return +} + +func parseFloat64(in string) (f64 float64, err error) { + var val float64 + + val, err = strconv.ParseFloat(in, 64) + if err == nil { + f64 = val + } + + return +} + +func parseAddressAndQuantity(in string) (addr uint16, quantity uint16, err error) { + var split = strings.Split(in, "+") + + switch { + case len(split) == 1: + addr, err = parseUint16(in) + + case len(split) == 2: + addr, err = parseUint16(split[0]) + if err != nil { + return + } + quantity, err = parseUint16(split[1]) + default: + err = errors.New("illegal format") + } + + return +} + +func parseUnitId(in string) (addr uint8, err error) { + var val uint64 + + val, err = strconv.ParseUint(in, 0, 8) + if err == nil { + addr = uint8(val) + } + + return +} + +func parseHexBytes(in string) (out []byte, err error) { + out, err = hex.DecodeString(in) + + return +} + +func performBoolScan(client *modbus.ModbusClient, isCoil bool) { + var err error + var addr uint32 + var val bool + var count uint + var regType string + + if isCoil { + regType = "coil" + } else { + regType = "discrete input" + } + + fmt.Printf("starting %s scan\n", regType) + + for addr = 0; addr <= 0xffff; addr++ { + if isCoil { + val, err = client.ReadCoil(uint16(addr)) + } else { + val, err = client.ReadDiscreteInput(uint16(addr)) + } + if err == modbus.ErrIllegalDataAddress || err == modbus.ErrIllegalFunction { + // the register does not exist + continue + } else if err != nil { + fmt.Printf("failed to read %s at address 0x%04x: %v\n", + regType, addr, err) + } else { + // we found a coil: display its address and value + fmt.Printf("0x%04x\t%-5v : %v\n", addr, addr, val) + count++ + } + } + + fmt.Printf("found %v %ss\n", count, regType) + + return +} + +func performRegisterScan(client *modbus.ModbusClient, isHoldingReg bool) { + var err error + var addr uint32 + var val uint16 + var count uint + var regType string + + if isHoldingReg { + regType = "holding register" + } else { + regType = "input register" + } + + fmt.Printf("starting %s scan\n", regType) + + for addr = 0; addr <= 0xffff; addr++ { + if isHoldingReg { + val, err = client.ReadRegister(uint16(addr), modbus.HOLDING_REGISTER) + } else { + val, err = client.ReadRegister(uint16(addr), modbus.INPUT_REGISTER) + } + if err == modbus.ErrIllegalDataAddress || err == modbus.ErrIllegalFunction { + // the register does not exist + continue + } else if err != nil { + fmt.Printf("failed to read %s at address 0x%04x: %v\n", + regType, addr, err) + } else { + // we found a register: display its address and value + fmt.Printf("0x%04x\t%-5v : 0x%04x\t%v\n", + addr, addr, val, val) + count++ + } + } + + fmt.Printf("found %v %ss\n", count, regType) + + return +} + +func performUnitIdScan(client *modbus.ModbusClient) { + var err error + var countOk uint + var countErr uint + var countTimeout uint + var countGWTimeout uint + + fmt.Println("starting unit id scan") + + for unitId := uint(0); unitId <= 0xff; unitId++ { + client.SetUnitId(uint8(unitId)) + + _, err = client.ReadRegister(0, modbus.INPUT_REGISTER) + switch err { + case nil, + modbus.ErrIllegalDataAddress, + modbus.ErrIllegalFunction, + modbus.ErrIllegalDataValue: + fmt.Printf("0x%02x (%3v): ok\n", unitId, unitId) + countOk++ + + case modbus.ErrRequestTimedOut: + countTimeout++ + + case modbus.ErrGWTargetFailedToRespond: + countGWTimeout++ + + default: + fmt.Printf("0x%02x (%3v): %v\n", unitId, unitId, err) + countErr++ + } + } + + fmt.Printf("found %v devices (%v errors, %v timeouts, %v gateway timeouts)\n", + countOk, countErr, countTimeout, countGWTimeout) + + return +} + +func performPing(client *modbus.ModbusClient, count uint16, interval time.Duration) { + var err error + var okCount uint + var timeoutCount uint + var otherErrCount uint + var startTs time.Time + var ts time.Time + var rtt time.Duration + var minRTT time.Duration + var maxRTT time.Duration + var avgRTT time.Duration + + fmt.Printf("ping: sending %v requests...\n", count) + + startTs = time.Now() + + for run := uint16(0); run < count; run++ { + ts = time.Now() + _, err = client.ReadRegister(0x0000, modbus.HOLDING_REGISTER) + + rtt = time.Since(ts) + avgRTT += rtt + + if run == 0 || rtt < minRTT { + minRTT = rtt + } + + if rtt > maxRTT { + maxRTT = rtt + } + + switch err { + // mask illegal data address and illegal function errors since we + // only care about getting a response from the target device + // (on which holding reg #0 may or may not exist) + case nil, modbus.ErrIllegalDataAddress, modbus.ErrIllegalFunction: + okCount++ + fmt.Printf("ok: seq = %v, time: %v\n", + run + 1, rtt.Round(time.Microsecond)) + + case modbus.ErrRequestTimedOut, modbus.ErrGWTargetFailedToRespond: + timeoutCount++ + fmt.Printf("timeout (%v): seq = %v, time: %v\n", + err, run + 1, rtt.Round(time.Microsecond)) + + default: + otherErrCount++ + fmt.Printf("error (%v): seq = %v, time: %v\n", + err, run + 1, rtt.Round(time.Microsecond)) + } + + if interval > 0 { + time.Sleep(interval) + } + } + + fmt.Printf("--- ping statistics ---\n" + + "%v queries, %v target replies, %v transmission errors, %v timeouts, time: %v\n", + count, okCount, otherErrCount, timeoutCount, + time.Since(startTs).Round(time.Millisecond)) + + fmt.Printf("rtt min/avg/max = %v/%v/%v\n", + minRTT.Round(time.Microsecond), + (avgRTT / time.Duration(count)).Round(time.Microsecond), + maxRTT.Round(time.Microsecond)) + + return +} + +func decodeString(in []byte) (out string) { + var dec []byte + var b byte + + for idx := range in { + if in[idx] >= 0x20 && in[idx] <= 0x7e { + b = in[idx] + } else { + b = '.' + } + + dec = append(dec, b) + } + + out = string(dec) + + return +} + +func displayHelp() { + flag.CommandLine.SetOutput(os.Stdout) + + fmt.Println( +`This tool is a modbus command line interface client meant to allow quick and easy +interaction with modbus devices (e.g. for probing or troubleshooting). + +Available options:`) + flag.PrintDefaults() + fmt.Printf( +` + +Commands must be given as trailing arguments after any options. + +Example: modbus-cli --target=tcp://somehost:502 --timeout=3s rh:uint16:0x100+5 wc:12:true + Read 6 holding registers at address 0x100 then set the coil at address 12 to true + on modbus/tcp device somehost port 502, with a timeout of 3s. + +Available commands: +* :[+additional quantity] + Read coil at address , plus any additional coils if specified. + + rc:0x100+199 reads 200 coils starting at address 0x100 (hex) + rc:300 reads 1 coil at address 300 (decimal) + +* :[+additional quantity] + Read discrete input at address , plus any additional discrete inputs if specified. + + rdi:0x100+199 reads 200 discrete inputs starting at address 0x100 (hex) + rdi:300 reads 1 discrete input at address 300 (decimal) + +* ::[+additional quantity] + Read holding registers at address , plus any additional registers if specified, + decoded as which should be one of: + - uint16: unsigned 16-bit integer, + - int16: signed 16-bit integer, + - uint32: unsigned 32-bit integer (2 contiguous modbus registers), + - int32: signed 32-bit integer (2 contiguous modbus registers), + - float32: 32-bit floating point number (2 contiguous modbus registers), + - uint64: unsigned 64-bit integer (4 contiguous modbus registers), + - int64: signed 64-bit integer (4 contiguous modbus registers), + - float64: 64-bit floating point number (4 contiguous modbus registers), + - bytes: string of bytes (2 bytes per modbus register). + + rh:int16:0x300+1 reads 2 consecutive 16-bit signed integers at addresses 0x300 and 0x301 + rh:uint32:20 reads a 32-bit unsigned integer at addresses 20-21 (2 modbus registers) + rh:float32:500+10 reads 11 32-bit floating point numbers at addresses 500-521 + (11 * 32bit make for 22 16-bit contiguous modbus registers) + +* ::[+additional quanitity] + Read input registers at address , plus any additional registers if specified, decoded + in the same way as explained above. + + ri:uint16:0x300+1 reads 2 consecutive 16-bit unsigned integers at addresses 0x300 and 0x301 + ri:int32:20 reads a 32-bit signed integer at addresses 20-21 (2 modbus registers) + +* :: + Set the coil at address to either true or false, depending on . + + wc:1:true writes true to the coil at address 1 + wc:2:false writes false to the coil at address 2 + +* ::: + Write to register(s) at address , using the encoding given by . + + wr:int16:0xf100:-10 writes -10 as a 16-bit signed integer at address 0xf100 + (1 modbus register) + wr:int32:0xff00:0xff writes 0xff as a 32-bit signed integer at addresses 0xff00-0xff01 + (2 consecutive modbus registers) + wr:float64:100:-3.2 writes -3.2 as a 64-bit float at addresses 100-103 + (4 consecutive modbus registers) + wr:bytes:5:fafbfcfd writes 0xfafbfcfd as a 4-byte string at addresses 5-6 + (2 consecutive modbus registers) + +* sleep: + Pause for , specified as a golang duration string. + + sleep:300s sleeps for 300 seconds + sleep:3m sleeps for 3 minutes + sleep:3ms sleeps for 3 milliseconds + +* : + Switch to unit id (slave id) for subsequent requests. + + sid:10 selects unit id #10 + +* repeat + Restart execution of the given commands. + + rh:uint32:100 sleep:1s repeat reads a 32-bit unsigned integer at addresses 100-101 and + pauses for one second, forever in a loop. + +* date + Print the current date and time (can be useful for long-running scripts). + +* scan: + Perform a modbus "scan" of the modbus type , which can be one of: + - "c", "coils", + - "di", "discreteInputs", + - "hr", "holdingRegisters", + - "ir", "inputRegisters", + - "s", "sid". + + scan:hr scans the device for holding registers. + scan:di scans the device for discrete inputs. + + Read requests are made over the entire address space (65535 addresses). + Adresses for which a non-error response is received are listed, along with the value received. + Errors other than Illegal Data Address and Illegal Function are also shown, as they should + not happen in sane implementations. + + scan:sid scans the target for devices. + + Scans all unit IDs (0 to 255) using a single read input register request. Addresses responding + positively or with non-timeout errors are shown, while timeouts and gateway timeouts are ignored. + This command can be used to find active nodes on RS485 buses, behind gateways or in composite + devices. + +* ping:[:interval] + Executes modbus reads (1 holding register at address 0x0000), either back to back or + separated by [interval] if specified, then prints timing and outcome statistics. + This command can be used to troubleshoot network or serial connections. + +Register endianness and word order: + The endianness of holding/input registers can be specified with --endianness and + defaults to big endian (as per the modbus spec). + For constructs spanning multiple consecutive registers (namely [u]int32, float32, [u]int64 and + float64), the word order can be set with --word-order and arbitrarily + defaults to highfirst (i.e. most significant word first). + +Supported transports and associated target schemes: + - Modbus RTU using a local serial device: rtu:///path/to/device + - Modbus RTU over TCP (RTU framing over a TCP socket): rtuovertcp://host:port + - Modbus RTU over UDP (RTU framing over an UDP socket): rtuoverudp://host:port + - Modbus TCP (MBAP): tcp://host:port + - Modbus TCP over TLS (MBAPS or Modbus Security): tcp+tls://host:port + - Modbus TCP over UDP (MBAP over UDP): udp://host:port +Note that UDP transports are not part of the Modbus protocol specification. + +Examples: + $ modbus-cli --target tcp://10.100.0.10:502 rh:uint32:0x100+5 rc:0+10 wc:3:true + Connect to 10.100.0.10 port 502, read 6 consecutive 32-bit unsigned integers at addresses + 0x100-0x10b (12 modbus registers) and 11 coils at addresses 0-10, then set the coil at + address 3 to true. + + $ modbus-cli --target rtu:///dev/ttyUSB0 --speed 19200 suid:2 rh:uint16:0+7 \ + wr:uint16:0x2:0x0605 suid:3 ri:int16:0+1 sleep:1s repeat + Open serial port /dev/ttyUSB0 at a speed of 19200 bps and repeat forever: + select unit id (slave id) 2, read holding registers at addresses 0-7 as 16 bit unsigned + integers, write 0x605 as a 16-bit unsigned integer at address 2, + change for unit id 3, read input registers 0-1 as 16-bit signed integers, + pause for 1s. + + $ modbus-cli --target tcp://somehost:502 scan:hr scan:ir scan:di scan:coils + Connect to somehost port 502 and perform a scan of all modbus types (namely + holding registers, input registers, discrete inputs and coils). + + $ modbus-cli --target tcp+tls://securehost:802 --cert client.cert.pem --key client.key.pem \ + --ca ca.cert.pem rh:uint32:0x3000 + Connect to securehost port 802 using modbus/TCP over TLS, using client.cert.pem and + client.key.pem to authenticate to the server (client auth) and ca.cert.pem to authenticate + the server, then read holding registers 0x3000-0x3001 as a 32-bit unsigned integer. + Note that ca.cert.pem can either be a CA (Certificate Authority) or the server (leaf) + certificate. +`) + + return +} diff --git a/crc.go b/crc.go new file mode 100644 index 0000000..6c90e81 --- /dev/null +++ b/crc.go @@ -0,0 +1,73 @@ +package modbus + +var crcTable [256]uint16 = [256]uint16{ + 0x0000, 0xc0c1, 0xc181, 0x0140, 0xc301, 0x03c0, 0x0280, 0xc241, + 0xc601, 0x06c0, 0x0780, 0xc741, 0x0500, 0xc5c1, 0xc481, 0x0440, + 0xcc01, 0x0cc0, 0x0d80, 0xcd41, 0x0f00, 0xcfc1, 0xce81, 0x0e40, + 0x0a00, 0xcac1, 0xcb81, 0x0b40, 0xc901, 0x09c0, 0x0880, 0xc841, + 0xd801, 0x18c0, 0x1980, 0xd941, 0x1b00, 0xdbc1, 0xda81, 0x1a40, + 0x1e00, 0xdec1, 0xdf81, 0x1f40, 0xdd01, 0x1dc0, 0x1c80, 0xdc41, + 0x1400, 0xd4c1, 0xd581, 0x1540, 0xd701, 0x17c0, 0x1680, 0xd641, + 0xd201, 0x12c0, 0x1380, 0xd341, 0x1100, 0xd1c1, 0xd081, 0x1040, + 0xf001, 0x30c0, 0x3180, 0xf141, 0x3300, 0xf3c1, 0xf281, 0x3240, + 0x3600, 0xf6c1, 0xf781, 0x3740, 0xf501, 0x35c0, 0x3480, 0xf441, + 0x3c00, 0xfcc1, 0xfd81, 0x3d40, 0xff01, 0x3fc0, 0x3e80, 0xfe41, + 0xfa01, 0x3ac0, 0x3b80, 0xfb41, 0x3900, 0xf9c1, 0xf881, 0x3840, + 0x2800, 0xe8c1, 0xe981, 0x2940, 0xeb01, 0x2bc0, 0x2a80, 0xea41, + 0xee01, 0x2ec0, 0x2f80, 0xef41, 0x2d00, 0xedc1, 0xec81, 0x2c40, + 0xe401, 0x24c0, 0x2580, 0xe541, 0x2700, 0xe7c1, 0xe681, 0x2640, + 0x2200, 0xe2c1, 0xe381, 0x2340, 0xe101, 0x21c0, 0x2080, 0xe041, + 0xa001, 0x60c0, 0x6180, 0xa141, 0x6300, 0xa3c1, 0xa281, 0x6240, + 0x6600, 0xa6c1, 0xa781, 0x6740, 0xa501, 0x65c0, 0x6480, 0xa441, + 0x6c00, 0xacc1, 0xad81, 0x6d40, 0xaf01, 0x6fc0, 0x6e80, 0xae41, + 0xaa01, 0x6ac0, 0x6b80, 0xab41, 0x6900, 0xa9c1, 0xa881, 0x6840, + 0x7800, 0xb8c1, 0xb981, 0x7940, 0xbb01, 0x7bc0, 0x7a80, 0xba41, + 0xbe01, 0x7ec0, 0x7f80, 0xbf41, 0x7d00, 0xbdc1, 0xbc81, 0x7c40, + 0xb401, 0x74c0, 0x7580, 0xb541, 0x7700, 0xb7c1, 0xb681, 0x7640, + 0x7200, 0xb2c1, 0xb381, 0x7340, 0xb101, 0x71c0, 0x7080, 0xb041, + 0x5000, 0x90c1, 0x9181, 0x5140, 0x9301, 0x53c0, 0x5280, 0x9241, + 0x9601, 0x56c0, 0x5780, 0x9741, 0x5500, 0x95c1, 0x9481, 0x5440, + 0x9c01, 0x5cc0, 0x5d80, 0x9d41, 0x5f00, 0x9fc1, 0x9e81, 0x5e40, + 0x5a00, 0x9ac1, 0x9b81, 0x5b40, 0x9901, 0x59c0, 0x5880, 0x9841, + 0x8801, 0x48c0, 0x4980, 0x8941, 0x4b00, 0x8bc1, 0x8a81, 0x4a40, + 0x4e00, 0x8ec1, 0x8f81, 0x4f40, 0x8d01, 0x4dc0, 0x4c80, 0x8c41, + 0x4400, 0x84c1, 0x8581, 0x4540, 0x8701, 0x47c0, 0x4680, 0x8641, + 0x8201, 0x42c0, 0x4380, 0x8341, 0x4100, 0x81c1, 0x8081, 0x4040, +} + +type crc struct { + crc uint16 +} + +// Prepares the CRC generator for use. +func (c *crc) init() { + c.crc = 0xffff + + return +} + +// Adds the given bytes to the CRC. +func (c *crc) add(in []byte) { + var index byte + + for _, b := range in { + index = b ^ byte(c.crc & 0xff) + c.crc >>= 8 + c.crc ^= crcTable[index] + } + + return +} + +// Returns the CRC as two bytes, swapped. +func (c *crc) value() (value []byte) { + value = uint16ToBytes(LITTLE_ENDIAN, c.crc) + + return +} + +func (c *crc) isEqual(low byte, high byte) (yes bool) { + yes = (bytesToUint16(LITTLE_ENDIAN, []byte{low, high}) == c.crc) + + return +} diff --git a/crc_test.go b/crc_test.go new file mode 100644 index 0000000..771ec06 --- /dev/null +++ b/crc_test.go @@ -0,0 +1,106 @@ +package modbus + +import ( + "testing" +) + +func TestCRC(t *testing.T) { + var c crc + var out []byte + + // initialize the CRC object and make sure we get 0xffff as init value + c.init() + if c.crc != 0xffff { + t.Errorf("expected 0xffff, saw 0x%04x", c.crc) + } + + out = c.value() + if len(out) != 2 { + t.Errorf("value() should have returned 2 bytes, got %v", len(out)) + } + if out[0] != 0xff || out[1] != 0xff { + t.Errorf("expected {0xff, 0xff} got {0x%02x, 0x%02x}", out[0], out[1]) + } + + // add a few bytes, check the output + c.add([]byte{0x01, 0x02, 0x03, 0x04, 0x05}) + if c.crc != 0xbb2a { + t.Errorf("expected 0xbb2a, saw 0x%04x", c.crc) + } + + out = c.value() + if len(out) != 2 { + t.Errorf("value() should have returned 2 bytes, got %v", len(out)) + } + if out[0] != 0x2a || out[1] != 0xbb { + t.Errorf("expected {0x2a, 0xbb} got {0x%02x, 0x%02x}", out[0], out[1]) + } + + // add one extra byte, test the output again + c.add([]byte{0x06}) + if c.crc != 0xddba { + t.Errorf("expected 0xddba, saw 0x%04x", c.crc) + } + + out = c.value() + if len(out) != 2 { + t.Errorf("value() should have returned 2 bytes, got %v", len(out)) + } + if out[0] != 0xba || out[1] != 0xdd { + t.Errorf("expected {0xba, 0xdd} got {0x%02x, 0x%02x}", out[0], out[1]) + } + + // init the CRC once again: the output bytes should be back to 0xffff + c.init() + if c.crc != 0xffff { + t.Errorf("expected 0xffff, saw 0x%04x", c.crc) + } + + out = c.value() + if len(out) != 2 { + t.Errorf("value() should have returned 2 bytes, got %v", len(out)) + } + if out[0] != 0xff || out[1] != 0xff { + t.Errorf("expected {0xff, 0xff} got {0x%02x, 0x%02x}", out[0], out[1]) + } + + return +} + +func TestCRCIsEqual(t *testing.T) { + var c crc + var out []byte + + // initialize the CRC object and feed it a few bytes + c.init() + c.add([]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06}) + + // make sure the register value is what it should be + if c.crc != 0xddba { + t.Errorf("expected 0xddba, saw 0x%04x", c.crc) + } + + // positive test + if !c.isEqual(0xba, 0xdd) { + t.Error("isEqual() should have returned true") + } + + // negative test + if c.isEqual(0xdd, 0xba) { + t.Error("isEqual() should have returned false") + } + + // loopback test + out = c.value() + if !c.isEqual(out[0], out[1]) { + t.Error("isEqual() should have returned true") + } + + // an empty payload should have a CRC of 0xffff + c.init() + if !c.isEqual(0xff, 0xff) { + t.Error("isEqual() should have returned true") + } + + return +} diff --git a/encoding.go b/encoding.go new file mode 100644 index 0000000..ffedbd7 --- /dev/null +++ b/encoding.go @@ -0,0 +1,209 @@ +package modbus + +import ( + "encoding/binary" + "math" +) + +func uint16ToBytes(endianness Endianness, in uint16) (out []byte) { + out = make([]byte, 2) + switch endianness { + case BIG_ENDIAN: binary.BigEndian.PutUint16(out, in) + case LITTLE_ENDIAN: binary.LittleEndian.PutUint16(out, in) + } + + return +} + +func uint16sToBytes(endianness Endianness, in []uint16) (out []byte) { + for i := range in { + out = append(out, uint16ToBytes(endianness, in[i])...) + } + + return +} + +func bytesToUint16(endianness Endianness, in []byte) (out uint16) { + switch endianness { + case BIG_ENDIAN: out = binary.BigEndian.Uint16(in) + case LITTLE_ENDIAN: out = binary.LittleEndian.Uint16(in) + } + + return +} + +func bytesToUint16s(endianness Endianness, in []byte) (out []uint16) { + for i := 0; i < len(in); i += 2 { + out = append(out, bytesToUint16(endianness, in[i:i+2])) + } + + return +} + +func bytesToUint32s(endianness Endianness, wordOrder WordOrder, in []byte) (out []uint32) { + var u32 uint32 + + for i := 0; i < len(in); i += 4 { + switch endianness { + case BIG_ENDIAN: + if wordOrder == HIGH_WORD_FIRST { + u32 = binary.BigEndian.Uint32(in[i:i+4]) + } else { + u32 = binary.BigEndian.Uint32( + []byte{in[i+2], in[i+3], in[i+0], in[i+1]}) + } + case LITTLE_ENDIAN: + if wordOrder == LOW_WORD_FIRST { + u32 = binary.LittleEndian.Uint32(in[i:i+4]) + } else { + u32 = binary.LittleEndian.Uint32( + []byte{in[i+2], in[i+3], in[i+0], in[i+1]}) + } + } + + out = append(out, u32) + } + + return +} + +func uint32ToBytes(endianness Endianness, wordOrder WordOrder, in uint32) (out []byte) { + out = make([]byte, 4) + + switch endianness { + case BIG_ENDIAN: + binary.BigEndian.PutUint32(out, in) + + // swap words if needed + if wordOrder == LOW_WORD_FIRST { + out[0], out[1], out[2], out[3] = out[2], out[3], out[0], out[1] + } + case LITTLE_ENDIAN: + binary.LittleEndian.PutUint32(out, in) + + // swap words if needed + if wordOrder == HIGH_WORD_FIRST { + out[0], out[1], out[2], out[3] = out[2], out[3], out[0], out[1] + } + } + + return +} + +func bytesToFloat32s(endianness Endianness, wordOrder WordOrder, in []byte) (out []float32) { + var u32s []uint32 + + u32s = bytesToUint32s(endianness, wordOrder, in) + + for _, u32 := range u32s { + out = append(out, math.Float32frombits(u32)) + } + + return +} + +func float32ToBytes(endianness Endianness, wordOrder WordOrder, in float32) (out []byte) { + out = uint32ToBytes(endianness, wordOrder, math.Float32bits(in)) + + return +} + +func bytesToUint64s(endianness Endianness, wordOrder WordOrder, in []byte) (out []uint64) { + var u64 uint64 + + for i := 0; i < len(in); i += 8 { + switch endianness { + case BIG_ENDIAN: + if wordOrder == HIGH_WORD_FIRST { + u64 = binary.BigEndian.Uint64(in[i:i+8]) + } else { + u64 = binary.BigEndian.Uint64( + []byte{in[i+6], in[i+7], in[i+4], in[i+5], + in[i+2], in[i+3], in[i+0], in[i+1]}) + } + case LITTLE_ENDIAN: + if wordOrder == LOW_WORD_FIRST { + u64 = binary.LittleEndian.Uint64(in[i:i+8]) + } else { + u64 = binary.LittleEndian.Uint64( + []byte{in[i+6], in[i+7], in[i+4], in[i+5], + in[i+2], in[i+3], in[i+0], in[i+1]}) + } + } + + out = append(out, u64) + } + + return +} + +func uint64ToBytes(endianness Endianness, wordOrder WordOrder, in uint64) (out []byte) { + out = make([]byte, 8) + + switch endianness { + case BIG_ENDIAN: + binary.BigEndian.PutUint64(out, in) + + // swap words if needed + if wordOrder == LOW_WORD_FIRST { + out[0], out[1], out[2], out[3],out[4], out[5], out[6], out[7] = + out[6], out[7], out[4], out[5], out[2], out[3], out[0], out[1] + } + case LITTLE_ENDIAN: + binary.LittleEndian.PutUint64(out, in) + + // swap words if needed + if wordOrder == HIGH_WORD_FIRST { + out[0], out[1], out[2], out[3],out[4], out[5], out[6], out[7] = + out[6], out[7], out[4], out[5], out[2], out[3], out[0], out[1] + } + } + + return +} + +func bytesToFloat64s(endianness Endianness, wordOrder WordOrder, in []byte) (out []float64) { + var u64s []uint64 + + u64s = bytesToUint64s(endianness, wordOrder, in) + + for _, u64 := range u64s { + out = append(out, math.Float64frombits(u64)) + } + + return +} + +func float64ToBytes(endianness Endianness, wordOrder WordOrder, in float64) (out []byte) { + out = uint64ToBytes(endianness, wordOrder, math.Float64bits(in)) + + return +} + +func encodeBools(in []bool) (out []byte) { + var byteCount uint + var i uint + + byteCount = uint(len(in)) / 8 + if len(in) % 8 != 0 { + byteCount++ + } + + out = make([]byte, byteCount) + for i = 0; i < uint(len(in)); i++ { + if in[i] { + out[i/8] |= (0x01 << (i % 8)) + } + } + + return +} + +func decodeBools(quantity uint16, in []byte) (out []bool) { + var i uint + for i = 0; i < uint(quantity); i++ { + out = append(out, (((in[i/8] >> (i % 8)) & 0x01) == 0x01)) + } + + return +} diff --git a/encoding_test.go b/encoding_test.go new file mode 100644 index 0000000..65bd5b4 --- /dev/null +++ b/encoding_test.go @@ -0,0 +1,610 @@ +package modbus + +import ( + "testing" +) + +func TestUint16ToBytes(t *testing.T) { + var out []byte + + out = uint16ToBytes(BIG_ENDIAN, 0x4321) + if len(out) != 2 { + t.Errorf("expected 2 bytes, got %v", len(out)) + } + if out[0] != 0x43 || out[1] != 0x21 { + t.Errorf("expected {0x43, 0x21}, got {0x%02x, 0x%02x}", out[0], out[1]) + } + + out = uint16ToBytes(LITTLE_ENDIAN, 0x4321) + if len(out) != 2 { + t.Errorf("expected 2 bytes, got %v", len(out)) + } + if out[0] != 0x21 || out[1] != 0x43 { + t.Errorf("expected {0x21, 0x43}, got {0x%02x, 0x%02x}", out[0], out[1]) + } + + return +} + +func TestUint16sToBytes(t *testing.T) { + var out []byte + + out = uint16sToBytes(BIG_ENDIAN, []uint16{0x4321, 0x8765, 0xcba9}) + if len(out) != 6 { + t.Errorf("expected 6 bytes, got %v", len(out)) + } + if out[0] != 0x43 || out[1] != 0x21 { + t.Errorf("expected {0x43, 0x21}, got {0x%02x, 0x%02x}", out[0], out[1]) + } + if out[2] != 0x87 || out[3] != 0x65 { + t.Errorf("expected {0x87, 0x65}, got {0x%02x, 0x%02x}", out[0], out[1]) + } + if out[4] != 0xcb || out[5] != 0xa9 { + t.Errorf("expected {0xcb, 0xa9}, got {0x%02x, 0x%02x}", out[0], out[1]) + } + + out = uint16sToBytes(LITTLE_ENDIAN, []uint16{0x4321, 0x8765, 0xcba9}) + if len(out) != 6 { + t.Errorf("expected 6 bytes, got %v", len(out)) + } + if out[0] != 0x21 || out[1] != 0x43 { + t.Errorf("expected {0x21, 0x43}, got {0x%02x, 0x%02x}", out[0], out[1]) + } + if out[2] != 0x65 || out[3] != 0x87 { + t.Errorf("expected {0x65, 0x87}, got {0x%02x, 0x%02x}", out[0], out[1]) + } + if out[4] != 0xa9 || out[5] != 0xcb { + t.Errorf("expected {0xa9, 0xcb}, got {0x%02x, 0x%02x}", out[0], out[1]) + } + + return +} + +func TestBytesToUint16(t *testing.T) { + var result uint16 + + result = bytesToUint16(BIG_ENDIAN, []byte{0x43, 0x21}) + if result != 0x4321 { + t.Errorf("expected 0x4321, got 0x%04x", result) + } + + result = bytesToUint16(LITTLE_ENDIAN, []byte{0x43, 0x21}) + if result != 0x2143 { + t.Errorf("expected 0x2143, got 0x%04x", result) + } + + return +} + +func TestBytesToUint16s(t *testing.T) { + var results []uint16 + + results = bytesToUint16s(BIG_ENDIAN, []byte{0x11, 0x22, 0x33, 0x44}) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0x1122 { + t.Errorf("expected 0x1122, got 0x%04x", results[0]) + } + if results[1] != 0x3344 { + t.Errorf("expected 0x3344, got 0x%04x", results[1]) + } + + results = bytesToUint16s(LITTLE_ENDIAN, []byte{0x11, 0x22, 0x33, 0x44}) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0x2211 { + t.Errorf("expected 0x2211, got 0x%04x", results[0]) + } + if results[1] != 0x4433 { + t.Errorf("expected 0x4433, got 0x%04x", results[1]) + } + + return +} + +func TestUint32ToBytes(t *testing.T) { + var out []byte + + out = uint32ToBytes(BIG_ENDIAN, HIGH_WORD_FIRST, 0x87654321) + if len(out) != 4 { + t.Errorf("expected 4 bytes, got %v", len(out)) + } + if out[0] != 0x87 || out[1] != 0x65 || // first word + out[2] != 0x43 || out[3] != 0x21 { // second word + t.Errorf("expected {0x87, 0x65, 0x43, 0x21}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3]) + } + + out = uint32ToBytes(BIG_ENDIAN, LOW_WORD_FIRST, 0x87654321) + if len(out) != 4 { + t.Errorf("expected 4 bytes, got %v", len(out)) + } + if out[0] != 0x43 || out[1] != 0x21 || out[2] != 0x87 || out[3] != 0x65 { + t.Errorf("expected {0x43, 0x21, 0x87, 0x65}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3]) + } + + out = uint32ToBytes(LITTLE_ENDIAN, LOW_WORD_FIRST, 0x87654321) + if len(out) != 4 { + t.Errorf("expected 4 bytes, got %v", len(out)) + } + if out[0] != 0x21 || out[1] != 0x43 || out[2] != 0x65 || out[3] != 0x87 { + t.Errorf("expected {0x21, 0x43, 0x65, 0x87}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3]) + } + + out = uint32ToBytes(LITTLE_ENDIAN, HIGH_WORD_FIRST, 0x87654321) + if len(out) != 4 { + t.Errorf("expected 4 bytes, got %v", len(out)) + } + if out[0] != 0x65 || out[1] != 0x87 || out[2] != 0x21 || out[3] != 0x43 { + t.Errorf("expected {0x65, 0x87, 0x21, 0x43}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3]) + } + + return +} + +func TestBytesToUint32s(t *testing.T) { + var results []uint32 + + results = bytesToUint32s(BIG_ENDIAN, HIGH_WORD_FIRST, []byte{ + 0x87, 0x65, 0x43, 0x21, + 0x00, 0x11, 0x22, 0x33, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0x87654321 { + t.Errorf("expected 0x87654321, got 0x%08x", results[0]) + } + if results[1] != 0x00112233 { + t.Errorf("expected 0x00112233, got 0x%08x", results[1]) + } + + results = bytesToUint32s(BIG_ENDIAN, LOW_WORD_FIRST, []byte{ + 0x87, 0x65, 0x43, 0x21, + 0x00, 0x11, 0x22, 0x33, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0x43218765 { + t.Errorf("expected 0x43218765, got 0x%08x", results[0]) + } + if results[1] != 0x22330011 { + t.Errorf("expected 0x22330011, got 0x%08x", results[1]) + } + + results = bytesToUint32s(LITTLE_ENDIAN, LOW_WORD_FIRST, []byte{ + 0x87, 0x65, 0x43, 0x21, + 0x00, 0x11, 0x22, 0x33, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0x21436587 { + t.Errorf("expected 0x21436587, got 0x%08x", results[0]) + } + if results[1] != 0x33221100 { + t.Errorf("expected 0x33221100, got 0x%08x", results[1]) + } + + results = bytesToUint32s(LITTLE_ENDIAN, HIGH_WORD_FIRST, []byte{ + 0x87, 0x65, 0x43, 0x21, + 0x00, 0x11, 0x22, 0x33, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0x65872143 { + t.Errorf("expected 0x65872143, got 0x%08x", results[0]) + } + if results[1] != 0x11003322 { + t.Errorf("expected 0x11003322, got 0x%08x", results[1]) + } + + return +} + +func TestFloat32ToBytes(t *testing.T) { + var out []byte + + out = float32ToBytes(BIG_ENDIAN, HIGH_WORD_FIRST, 1.234) + if len(out) != 4 { + t.Errorf("expected 4 bytes, got %v", len(out)) + } + if out[0] != 0x3f || out[1] != 0x9d || out[2] != 0xf3 || out[3] != 0xb6 { + t.Errorf("expected {0x3f, 0x9d, 0xf3, 0xb6}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3]) + } + + out = float32ToBytes(BIG_ENDIAN, LOW_WORD_FIRST, 1.234) + if len(out) != 4 { + t.Errorf("expected 4 bytes, got %v", len(out)) + } + if out[0] != 0xf3 || out[1] != 0xb6 || out[2] != 0x3f || out[3] != 0x9d { + t.Errorf("expected {0xf3, 0xb6, 0x3f, 0x9d}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3]) + } + + out = float32ToBytes(LITTLE_ENDIAN, LOW_WORD_FIRST, 1.234) + if len(out) != 4 { + t.Errorf("expected 4 bytes, got %v", len(out)) + } + if out[0] != 0xb6 || out[1] != 0xf3 || out[2] != 0x9d || out[3] != 0x3f { + t.Errorf("expected {0xb6, 0xf3, 0x9d, 0x3f}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3]) + } + + out = float32ToBytes(LITTLE_ENDIAN, HIGH_WORD_FIRST, 1.234) + if len(out) != 4 { + t.Errorf("expected 4 bytes, got %v", len(out)) + } + if out[0] != 0x9d || out[1] != 0x3f || out[2] != 0xb6 || out[3] != 0xf3 { + t.Errorf("expected {0x9d, 0x3f, 0xb6, 0xf3}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3]) + } + + return +} + +func TestBytesToFloat32s(t *testing.T) { + var results []float32 + + results = bytesToFloat32s(BIG_ENDIAN, HIGH_WORD_FIRST, []byte{ + 0x3f, 0x9d, 0xf3, 0xb6, + 0x40, 0x49, 0x0f, 0xdb, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 1.234 { + t.Errorf("expected 1.234, got %.04f", results[0]) + } + if results[1] != 3.14159274101 { + t.Errorf("expected 3.14159274101, got %.09f", results[1]) + } + + results = bytesToFloat32s(BIG_ENDIAN, LOW_WORD_FIRST, []byte{ + 0xf3, 0xb6, 0x3f, 0x9d, + 0x0f, 0xdb, 0x40, 0x49, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 1.234 { + t.Errorf("expected 1.234, got %.04f", results[0]) + } + if results[1] != 3.14159274101 { + t.Errorf("expected 3.14159274101, got %.09f", results[1]) + } + + results = bytesToFloat32s(LITTLE_ENDIAN, LOW_WORD_FIRST, []byte{ + 0xb6, 0xf3, 0x9d, 0x3f, + 0xdb, 0x0f, 0x49, 0x40, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 1.234 { + t.Errorf("expected 1.234, got %.04f", results[0]) + } + if results[1] != 3.14159274101 { + t.Errorf("expected 3.14159274101, got %.09f", results[1]) + } + + results = bytesToFloat32s(LITTLE_ENDIAN, HIGH_WORD_FIRST, []byte{ + 0x9d, 0x3f, 0xb6, 0xf3, + 0x49, 0x40, 0xdb, 0x0f, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 1.234 { + t.Errorf("expected 1.234, got %.04f", results[0]) + } + if results[1] != 3.14159274101 { + t.Errorf("expected 3.14159274101, got %.09f", results[1]) + } + + return +} + +func TestUint64ToBytes(t *testing.T) { + var out []byte + + out = uint64ToBytes(BIG_ENDIAN, HIGH_WORD_FIRST, 0x0fedcba987654321) + if len(out) != 8 { + t.Errorf("expected 8 bytes, got %v", len(out)) + } + + if out[0] != 0x0f || out[1] != 0xed || // 1st word + out[2] != 0xcb || out[3] != 0xa9 || // 2nd word + out[4] != 0x87 || out[5] != 0x65 || // 3rd word + out[6] != 0x43 || out[7] != 0x21 { // 4th word + t.Errorf("expected {0x0f, 0xed, 0xcb, 0xa9, 0x87, 0x65, 0x43, 0x21}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]) + } + + out = uint64ToBytes(BIG_ENDIAN, LOW_WORD_FIRST, 0x0fedcba987654321) + if len(out) != 8 { + t.Errorf("expected 8 bytes, got %v", len(out)) + } + if out[0] != 0x43 || out[1] != 0x21 || // 1st word + out[2] != 0x87 || out[3] != 0x65 || // 2nd word + out[4] != 0xcb || out[5] != 0xa9 || // 3rd word + out[6] != 0x0f || out[7] != 0xed { // 4th word + t.Errorf("expected {0x43, 0x21, 0x87, 0x65, 0xcb, 0xa9, 0x0f, 0xed}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]) + } + + out = uint64ToBytes(LITTLE_ENDIAN, LOW_WORD_FIRST, 0x0fedcba987654321) + if len(out) != 8 { + t.Errorf("expected 8 bytes, got %v", len(out)) + } + if out[0] != 0x21 || out[1] != 0x43 || // 1st word + out[2] != 0x65 || out[3] != 0x87 || // 2nd word + out[4] != 0xa9 || out[5] != 0xcb || // 3rd word + out[6] != 0xed || out[7] != 0x0f { // 4th word + t.Errorf("expected {0x21, 0x43, 0x65, 0x87, 0xa9, 0xcb, 0xed, 0x0f}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]) + } + + out = uint64ToBytes(LITTLE_ENDIAN, HIGH_WORD_FIRST, 0x0fedcba987654321) + if len(out) != 8 { + t.Errorf("expected 8 bytes, got %v", len(out)) + } + if out[0] != 0xed || out[1] != 0x0f || // 1st word + out[2] != 0xa9 || out[3] != 0xcb || // 2nd word + out[4] != 0x65 || out[5] != 0x87 || // 3rd word + out[6] != 0x21 || out[7] != 0x43 { // 4th word + t.Errorf("expected {0xed, 0x0f, 0xa9, 0xcb, 0x65, 0x87, 0x21, 0x43}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]) + } + + return +} + +func TestBytesToUint64s(t *testing.T) { + var results []uint64 + + results = bytesToUint64s(BIG_ENDIAN, HIGH_WORD_FIRST, []byte{ + 0x0f, 0xed, 0xcb, 0xa9, 0x87, 0x65, 0x43, 0x21, + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0x0fedcba987654321 { + t.Errorf("expected 0x0fedcba987654321, got 0x%016x", results[0]) + } + if results[1] != 0x0011223344556677 { + t.Errorf("expected 0x0011223344556677, got 0x%016x", results[1]) + } + + results = bytesToUint64s(BIG_ENDIAN, LOW_WORD_FIRST, []byte{ + 0x0f, 0xed, 0xcb, 0xa9, 0x87, 0x65, 0x43, 0x21, + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + + if results[0] != 0x43218765cba90fed { + t.Errorf("expected 0x43218765cba90fed, got 0x%016x", results[0]) + } + + if results[1] != 0x6677445522330011 { + t.Errorf("expected 0x6677445522330011, got 0x%016x", results[1]) + } + + results = bytesToUint64s(LITTLE_ENDIAN, LOW_WORD_FIRST, []byte{ + 0x0f, 0xed, 0xcb, 0xa9, 0x87, 0x65, 0x43, 0x21, + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0x21436587a9cbed0f { + t.Errorf("expected 0x21436587a9cbed0f, got 0x%016x", results[0]) + } + if results[1] != 0x7766554433221100 { + t.Errorf("expected 0x7766554433221100, got 0x%016x", results[1]) + } + + results = bytesToUint64s(LITTLE_ENDIAN, HIGH_WORD_FIRST, []byte{ + 0x0f, 0xed, 0xcb, 0xa9, 0x87, 0x65, 0x43, 0x21, + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 0xed0fa9cb65872143 { + t.Errorf("expected 0xed0fa9cb65872143, got 0x%016x", results[0]) + } + if results[1] != 0x1100332255447766 { + t.Errorf("expected 0x1100332255447766, got 0x%016x", results[1]) + } + + return +} + +func TestFloat64ToBytes(t *testing.T) { + var out []byte + + out = float64ToBytes(BIG_ENDIAN, HIGH_WORD_FIRST, 1.2345678) + if len(out) != 8 { + t.Errorf("expected 8 bytes, got %v", len(out)) + } + if out[0] != 0x3f || out[1] != 0xf3 || out[2] != 0xc0 || out[3] != 0xca || + out[4] != 0x2a || out[5] != 0x5b || out[6] != 0x1d || out[7] != 0x5d { + t.Errorf("expected {0x3f, 0xf3, 0xc0, 0xca, 0x2a, 0x5b, 0x1d, 0x5d}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]) + } + + out = float64ToBytes(BIG_ENDIAN, LOW_WORD_FIRST, 1.2345678) + if len(out) != 8 { + t.Errorf("expected 8 bytes, got %v", len(out)) + } + if out[0] != 0x1d || out[1] != 0x5d || out[2] != 0x2a || out[3] != 0x5b || + out[4] != 0xc0 || out[5] != 0xca || out[6] != 0x3f || out[7] != 0xf3 { + t.Errorf("expected {0x1d, 0x5d, 0x2a, 0x5b, 0xc0, 0xca, 0x3f, 0xf3}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]) + } + + out = float64ToBytes(LITTLE_ENDIAN, LOW_WORD_FIRST, 1.2345678) + if len(out) != 8 { + t.Errorf("expected 8 bytes, got %v", len(out)) + } + + if out[0] != 0x5d || out[1] != 0x1d || out[2] != 0x5b || out[3] != 0x2a || + out[4] != 0xca || out[5] != 0xc0 || out[6] != 0xf3 || out[7] != 0x3f { + t.Errorf("expected {0x5d, 0x1d, 0x5b, 0x2a, 0xca, 0xc0, 0xf3, 0x3f}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]) + } + + out = float64ToBytes(LITTLE_ENDIAN, HIGH_WORD_FIRST, 1.2345678) + if len(out) != 8 { + t.Errorf("expected 8 bytes, got %v", len(out)) + } + if out[0] != 0xf3 || out[1] != 0x3f || out[2] != 0xca || out[3] != 0xc0 || + out[4] != 0x5b || out[5] != 0x2a || out[6] != 0x5d || out[7] != 0x1d { + t.Errorf("expected {0xf3, 0x3f, 0xca, 0xc0, 0x5b, 0x2a, 0x5d, 0x1d}, got {0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x, 0x%02x}", + out[0], out[1], out[2], out[3], out[4], out[5], out[6], out[7]) + } + + return +} + +func TestBytesToFloat64s(t *testing.T) { + var results []float64 + + results = bytesToFloat64s(BIG_ENDIAN, HIGH_WORD_FIRST, []byte{ + 0x3f, 0xf3, 0xc0, 0xca, 0x2a, 0x5b, 0x1d, 0x5d, + 0x40, 0x09, 0x21, 0xfb, 0x5f, 0xff, 0xe9, 0x5e, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 1.2345678 { + t.Errorf("expected 1.2345678, got %.08f", results[0]) + } + if results[1] != 3.14159274101 { + t.Errorf("expected 3.14159274101, got %.09f", results[1]) + } + + results = bytesToFloat64s(BIG_ENDIAN, LOW_WORD_FIRST, []byte{ + 0x1d, 0x5d, 0x2a, 0x5b, 0xc0, 0xca, 0x3f, 0xf3, + 0xe9, 0x5e, 0x5f, 0xff, 0x21, 0xfb, 0x40, 0x09, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 1.2345678 { + t.Errorf("expected 1.234, got %.08f", results[0]) + } + if results[1] != 3.14159274101 { + t.Errorf("expected 3.14159274101, got %.09f", results[1]) + } + + results = bytesToFloat64s(LITTLE_ENDIAN, LOW_WORD_FIRST, []byte{ + 0x5d, 0x1d, 0x5b, 0x2a, 0xca, 0xc0, 0xf3, 0x3f, + 0x5e, 0xe9, 0xff, 0x5f, 0xfb, 0x21, 0x09, 0x40, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 1.2345678 { + t.Errorf("expected 1.234, got %.08f", results[0]) + } + if results[1] != 3.14159274101 { + t.Errorf("expected 3.14159274101, got %.09f", results[1]) + } + + results = bytesToFloat64s(LITTLE_ENDIAN, HIGH_WORD_FIRST, []byte{ + 0xf3, 0x3f, 0xca, 0xc0, 0x5b, 0x2a, 0x5d, 0x1d, + 0x09, 0x40, 0xfb, 0x21, 0xff, 0x5f, 0x5e, 0xe9, + }) + if len(results) != 2 { + t.Errorf("expected 2 values, got %v", len(results)) + } + if results[0] != 1.2345678 { + t.Errorf("expected 1.234, got %.08f", results[0]) + } + if results[1] != 3.14159274101 { + t.Errorf("expected 3.14159274101, got %.09f", results[1]) + } + + return +} + +func TestDecodeBools(t *testing.T) { + var results []bool + + results = decodeBools(1, []byte{0x01}) + if len(results) != 1 { + t.Errorf("expected 1 value, got %v", len(results)) + } + if results[0] != true { + t.Errorf("expected true, got false") + } + + results = decodeBools(1, []byte{0x0f}) + if len(results) != 1 { + t.Errorf("expected 1 value, got %v", len(results)) + } + if results[0] != true { + t.Errorf("expected true, got false") + } + + results = decodeBools(9, []byte{0x75, 0x03}) + if len(results) != 9 { + t.Errorf("expected 9 values, got %v", len(results)) + } + for i, b := range []bool{ + true, false, true, false, // 0x05 + true, true, true, false, // 0x07 + true, } { // 0x01 + if b != results[i] { + t.Errorf("expected %v at %v, got %v", b, i, results[i]) + } + } + + return +} + +func TestEncodeBools(t *testing.T) { + var results []byte + + results = encodeBools([]bool{false, true, false, true, }) + if len(results) != 1 { + t.Errorf("expected 1 byte, got %v", len(results)) + } + if results[0] != 0x0a { + t.Errorf("expected 0x0a, got 0x%02x", results[0]) + } + + results = encodeBools([]bool{true, false, true, }) + if len(results) != 1 { + t.Errorf("expected 1 byte, got %v", len(results)) + } + if results[0] != 0x05 { + t.Errorf("expected 0x05, got 0x%02x", results[0]) + } + + results = encodeBools([]bool{true, false, false, true, false, true, true, false, + true, true, true, false, true, true, true, false, + false, true}) + if len(results) != 3 { + t.Errorf("expected 3 bytes, got %v", len(results)) + } + if results[0] != 0x69 || results[1] != 0x77 || results[2] != 0x02 { + t.Errorf("expected {0x69, 0x77, 0x02}, got {0x%02x, 0x%02x, 0x%02x}", + results[0], results[1], results[2]) + } + + return +} diff --git a/examples/tcp_server.go b/examples/tcp_server.go new file mode 100644 index 0000000..08b8b91 --- /dev/null +++ b/examples/tcp_server.go @@ -0,0 +1,340 @@ +package main + +import ( + "fmt" + "os" + "math" + "sync" + "time" + + "github.com/simonvetter/modbus" +) + +const ( + MINUS_ONE int16 = -1 +) + +/* +* Simple modbus server example. +* +* This file is intended to be a demo of the modbus server. +* It shows how to create and start a server, as well as how +* to write a handler object. +* Feel free to use it as boilerplate for simple servers. +*/ + +// run this with go run examples/tcp_server.go +func main() { + var server *modbus.ModbusServer + var err error + var eh *exampleHandler + var ticker *time.Ticker + + // create the handler object + eh = &exampleHandler{} + + // create the server object + server, err = modbus.NewServer(&modbus.ServerConfiguration{ + // listen on localhost port 5502 + URL: "tcp://localhost:5502", + // close idle connections after 30s of inactivity + Timeout: 30 * time.Second, + // accept 5 concurrent connections max. + MaxClients: 5, + }, eh) + if err != nil { + fmt.Printf("failed to create server: %v\n", err) + os.Exit(1) + } + + // start accepting client connections + // note that Start() returns as soon as the server is started + err = server.Start() + if err != nil { + fmt.Printf("failed to start server: %v\n", err) + os.Exit(1) + } + + // increment a 32-bit uptime counter every second. + // (this counter is exposed as input registers 200-201 for demo purposes) + ticker = time.NewTicker(1 * time.Second) + for { + <-ticker.C + + // since the handler methods are called from multiple goroutines, + // use locking where appropriate to avoid concurrency issues. + eh.lock.Lock() + eh.uptime++ + eh.lock.Unlock() + } + + // never reached + return +} + +// Example handler object, passed to the NewServer() constructor above. +type exampleHandler struct { + // this lock is used to avoid concurrency issues between goroutines, as + // handler methods are called from different goroutines + // (1 goroutine per client) + lock sync.RWMutex + + // simple uptime counter, incremented in the main() above and exposed + // as a 32-bit input register (2 consecutive 16-bit modbus registers). + uptime uint32 + + // these are here to hold client-provided (written) values, for both coils and + // holding registers + coils [100]bool + holdingReg1 uint16 + holdingReg2 uint16 + + // this is a 16-bit signed integer + holdingReg3 int16 + + // this is a 32-bit unsigned integer + holdingReg4 uint32 +} + +// Coil handler method. +// This method gets called whenever a valid modbus request asking for a coil operation is +// received by the server. +// It exposes 100 read/writable coils at addresses 0-99, except address 80 which is +// read-only. +// (read them with ./modbus-cli --target tcp://localhost:5502 rc:0+99, write to register n +// with ./modbus-cli --target tcp://localhost:5502 wr:n:) +func (eh *exampleHandler) HandleCoils(req *modbus.CoilsRequest) (res []bool, err error) { + if req.UnitId != 1 { + // only accept unit ID #1 + // note: we're merely filtering here, but we could as well use the unit + // ID field to support multiple register maps in a single server. + err = modbus.ErrIllegalFunction + return + } + + // make sure that all registers covered by this request actually exist + if int(req.Addr) + int(req.Quantity) > len(eh.coils) { + err = modbus.ErrIllegalDataAddress + return + } + + // since we're manipulating variables shared between multiple goroutines, + // acquire a lock to avoid concurrency issues. + eh.lock.Lock() + // release the lock upon return + defer eh.lock.Unlock() + + // loop through `req.Quantity` registers, from address `req.Addr` to + // `req.Addr + req.Quantity - 1`, which here is conveniently `req.Addr + i` + for i := 0; i < int(req.Quantity); i++ { + // ignore the write if the current register address is 80 + if req.IsWrite && int(req.Addr) + i != 80 { + // assign the value + eh.coils[int(req.Addr) + i] = req.Args[i] + } + // append the value of the requested register to res so they can be + // sent back to the client + res = append(res, eh.coils[int(req.Addr) + i]) + } + + return +} + +// Discrete input handler method. +// Note that we're returning ErrIllegalFunction unconditionally. +// This will cause the client to receive "illegal function", which is the modbus way of +// reporting that this server does not support/implement the discrete input type. +func (eh *exampleHandler) HandleDiscreteInputs(req *modbus.DiscreteInputsRequest) (res []bool, err error) { + // this is the equivalent of saying + // "discrete inputs are not supported by this device" + // (try it with modbus-cli --target tcp://localhost:5502 rdi:1) + err = modbus.ErrIllegalFunction + + return +} + +// Holding register handler method. +// This method gets called whenever a valid modbus request asking for a holding register +// operation (either read or write) received by the server. +func (eh *exampleHandler) HandleHoldingRegisters(req *modbus.HoldingRegistersRequest) (res []uint16, err error) { + var regAddr uint16 + + if req.UnitId != 1 { + // only accept unit ID #1 + err = modbus.ErrIllegalFunction + return + } + + // since we're manipulating variables shared between multiple goroutines, + // acquire a lock to avoid concurrency issues. + eh.lock.Lock() + // release the lock upon return + defer eh.lock.Unlock() + + // loop through `quantity` registers + for i := 0; i < int(req.Quantity); i++ { + // compute the target register address + regAddr = req.Addr + uint16(i) + + switch regAddr { + // expose the static, read-only value of 0xff00 in register 100 + case 100: + res = append(res, 0xff00) + + // expose holdingReg1 in register 101 (RW) + case 101: + if req.IsWrite { + eh.holdingReg1 = req.Args[i] + } + res = append(res, eh.holdingReg1) + + // expose holdingReg2 in register 102 (RW) + case 102: + if req.IsWrite { + // only accept values 2 and 4 + switch req.Args[i] { + case 2, 4: + eh.holdingReg2 = req.Args[i] + + // make note of the change (e.g. for auditing purposes) + fmt.Printf("%s set reg#102 to %v\n", req.ClientAddr, eh.holdingReg2) + default: + // if the written value is neither 2 nor 4, + // return a modbus "illegal data value" to + // let the client know that the value is + // not acceptable. + err = modbus.ErrIllegalDataValue + return + } + } + res = append(res, eh.holdingReg2) + + // expose eh.holdingReg3 in register 103 (RW) + // note: eh.holdingReg3 is a signed 16-bit integer + case 103: + if req.IsWrite { + // cast the 16-bit unsigned integer passed by the server + // to a 16-bit signed integer when writing + eh.holdingReg3 = int16(req.Args[i]) + } + // cast the 16-bit signed integer from the handler to a 16-bit unsigned + // integer so that we can append it to `res`. + res = append(res, uint16(eh.holdingReg3)) + + + // expose the 16 most-significant bits of eh.holdingReg4 in register 200 + case 200: + if req.IsWrite { + eh.holdingReg4 = + ((uint32(req.Args[i]) << 16) & 0xffff0000 | + (eh.holdingReg4 & 0x0000ffff)) + } + res = append(res, uint16((eh.holdingReg4 >> 16) & 0x0000ffff)) + + // expose the 16 least-significant bits of eh.holdingReg4 in register 201 + case 201: + if req.IsWrite { + eh.holdingReg4 = + (uint32(req.Args[i]) & 0x0000ffff | + (eh.holdingReg4 & 0xffff0000)) + } + res = append(res, uint16(eh.holdingReg4 & 0x0000ffff)) + + + // any other address is unknown + default: + err = modbus.ErrIllegalDataAddress + return + } + } + + return +} + +// Input register handler method. +// This method gets called whenever a valid modbus request asking for an input register +// operation is received by the server. +// Note that input registers are always read-only as per the modbus spec. +func (eh *exampleHandler) HandleInputRegisters(req *modbus.InputRegistersRequest) (res []uint16, err error) { + var unixTs_s uint32 + var minusOne int16 = -1 + + if req.UnitId != 1 { + // only accept unit ID #1 + err = modbus.ErrIllegalFunction + return + } + + // get the current unix timestamp, converted as a 32-bit unsigned integer for + // simplicity + unixTs_s = uint32(time.Now().Unix() & 0xffffffff) + + // loop through all register addresses from req.addr to req.addr + req.Quantity - 1 + for regAddr := req.Addr; regAddr < req.Addr + req.Quantity; regAddr++ { + switch regAddr { + case 100: + // return the static value 0x1111 at address 100, as an unsigned + // 16-bit integer + // (read it with modbus-cli --target tcp://localhost:5502 ri:uint16:100) + res = append(res, 0x1111) + + case 101: + // return the static value -1 at address 101, as a signed 16-bit + // integer + // (read it with modbus-cli --target tcp://localhost:5502 ri:int16:101) + res = append(res, uint16(minusOne)) + + + // expose our uptime counter, encoded as a 32-bit unsigned integer in + // input registers 200-201 + // (read it with modbus-cli --target tcp://localhost:5502 ri:uint32:200) + case 200: + // return the 16 most significant bits of the uptime counter + // (using locking to avoid concurrency issues) + eh.lock.RLock() + res = append(res, uint16((eh.uptime >> 16) & 0xffff)) + eh.lock.RUnlock() + + case 201: + // return the 16 least significant bits of the uptime counter + // (again, using locking to avoid concurrency issues) + eh.lock.RLock() + res = append(res, uint16(eh.uptime & 0xffff)) + eh.lock.RUnlock() + + + // expose the current unix timestamp, encoded as a 32-bit unsigned integer + // in input registers 202-203 + // (read it with modbus-cli --target tcp://localhost:5502 ri:uint32:202) + case 202: + // return the 16 most significant bits of the current unix time + res = append(res, uint16((unixTs_s >> 16) & 0xffff)) + + case 203: + // return the 16 least significant bits of the current unix time + res = append(res, uint16(unixTs_s & 0xffff)) + + + // return 3.1415, encoded as a 32-bit floating point number in input + // registers 300-301 + // (read it with modbus-cli --target tcp://localhost:5502 ri:float32:300) + case 300: + // returh the 16 most significant bits of the number + res = append(res, uint16((math.Float32bits(3.1415) >> 16) & 0xffff)) + + case 301: + // returh the 16 least significant bits of the number + res = append(res, uint16((math.Float32bits(3.1415)) & 0xffff)) + + + // attempting to access any input register address other than + // those defined above will result in an illegal data address + // exception client-side. + default: + err = modbus.ErrIllegalDataAddress + return + } + } + + return +} diff --git a/examples/tls_client.go b/examples/tls_client.go new file mode 100644 index 0000000..34b990f --- /dev/null +++ b/examples/tls_client.go @@ -0,0 +1,92 @@ +package main + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "os" + + "github.com/simonvetter/modbus" +) + +/* + * Modbus client with TLS example. + * + * This file is intended to be a demo of the modbus client in TCP+TLS + * mode. It shows how to load certificates from files and how to + * configure the client to use them. + */ + +func main() { + var client *modbus.ModbusClient + var err error + var clientKeyPair tls.Certificate + var serverCertPool *x509.CertPool + var regs []uint16 + + // load the client certificate and its associated private key, which + // are used to authenticate the client to the server + clientKeyPair, err = tls.LoadX509KeyPair( + "certs/client.cert.pem", "certs/client.key.pem") + if err != nil { + fmt.Printf("failed to load client key pair: %v\n", err) + os.Exit(1) + } + + // load either the server certificate or the certificate of the CA + // (Certificate Authority) which signed the server certificate + serverCertPool, err = modbus.LoadCertPool("certs/server.cert.pem") + if err != nil { + fmt.Printf("failed to load server certificate/CA: %v\n", err) + os.Exit(1) + } + + // create a client targetting host secure-plc on port 802 using + // modbus TCP over TLS (MBAPS) + client, err = modbus.NewClient(&modbus.ClientConfiguration{ + // tcp+tls is the moniker for MBAPS (modbus/tcp encapsulated in + // TLS), + // 802/tcp is the IANA-registered port for MBAPS. + URL: "tcp+tls://secure-plc:802", + // set the client-side cert and key + TLSClientCert: &clientKeyPair, + // set the server/CA certificate + TLSRootCAs: serverCertPool, + }) + if err != nil { + fmt.Printf("failed to create modbus client: %v\n", err) + os.Exit(1) + } + + // now that the client is created and configured, attempt to connect + err = client.Open() + if err != nil { + fmt.Printf("failed to connect: %v\n", err) + os.Exit(2) + } + + // read two 16-bit holding registers at address 0x4000 + regs, err = client.ReadRegisters(0x4000, 2, modbus.HOLDING_REGISTER) + if err != nil { + fmt.Printf("failed to read registers 0x4000 and 0x4001: %v\n", err) + } else { + fmt.Printf("register 0x4000: 0x%04x\n", regs[0]) + fmt.Printf("register 0x4001: 0x%04x\n", regs[1]) + } + + // set register 0x4002 to 500 + err = client.WriteRegister(0x4002, 500) + if err != nil { + fmt.Printf("failed to write to register 0x4002: %v\n", err) + } else { + fmt.Printf("set register 0x4002 to 500\n") + } + + // close the connection + err = client.Close() + if err != nil { + fmt.Printf("failed to close connection: %v\n", err) + } + + os.Exit(0) +} diff --git a/examples/tls_server.go b/examples/tls_server.go new file mode 100644 index 0000000..94dfe44 --- /dev/null +++ b/examples/tls_server.go @@ -0,0 +1,255 @@ +package main + +import ( + "crypto/x509" + "crypto/tls" + "fmt" + "os" + "sync" + "time" + + "github.com/simonvetter/modbus" +) + +/* Modbus TCP+TLS (MBAPS or Modbus Security) server example. + * + * This file is intended to be a demo of the modbus server in a tcp+tls + * configuration. + * It shows how to configure and start a server, as well as how to use + * client roles to perform authorization in the handler. + * Feel free to use it as boilerplate for simple servers. + * + * This server simulates a simple wall clock device, exposing a 32-bit unix + * timestamp in holding registers #0 and 1. + * The timestamp is incremented every second by the main loop. + * + * Access control is done by way of Modbus Roles, which are encoded in the + * client certificate as an X509 extension: + * - any client can read the clock regardless of their role, provided that their + * certificate is accepted by the server, + * - only clients with the "operator" role specified in their certificate can + * set the time. + * + * Certificates with no, invalid or multiple Modbus Role extensions will have + * their role set to an empty string (req.ClientRole == ""). + * + * Requests from clients with certificates not passing TLS verification are + * rejected at the TLS layer (i.e. before reaching the Modbus layer). + * + * + * The following commands can be used to create self-signed server and client + * certificates: + * $ mkdir certs + * + * create the server key pair: + * $ openssl req -x509 -newkey rsa:4096 -sha256 -days 360 -nodes \ + * -keyout certs/server.key.pem -out certs/server.cert.pem \ + * -subj "/CN=TEST SERVER CERT DO NOT USE/" -addext "subjectAltName=DNS:localhost" \ + * -addext "keyUsage=keyCertSign,digitalSignature,keyEncipherment" \ + * -addext "extendedKeyUsage=critical,serverAuth" + * + * create a client certificate with the "user" role: + * $ openssl req -x509 -newkey rsa:4096 -sha256 -days 360 -nodes \ + * -keyout certs/user-client.key.pem -out certs/user-client.cert.pem \ + * -subj "/CN=TEST CLIENT CERT DO NOT USE/" \ + * -addext "keyUsage=keyCertSign,digitalSignature,keyEncipherment" \ + * -addext "extendedKeyUsage=critical,clientAuth" \ + * -addext "1.3.6.1.4.1.50316.802.1=ASN1:UTF8String:user" + * + * create another client certificate with the "operator" role: + * $ openssl req -x509 -newkey rsa:4096 -sha256 -days 360 -nodes \ + * -keyout certs/operator-client.key.pem -out certs/operator-client.cert.pem \ + * -subj "/CN=TEST CLIENT CERT DO NOT USE/" \ + * -addext "keyUsage=keyCertSign,digitalSignature,keyEncipherment" \ + * -addext "extendedKeyUsage=critical,clientAuth" \ + * -addext "1.3.6.1.4.1.50316.802.1=ASN1:UTF8String:operator" + * + * create a file containing both client certificates (for use by the server as an + * 'allowed client list'): + * $ cat certs/user-client.cert.pem certs/operator-client.cert.pem >certs/clients.cert.pem + * + * start the server: + * $ go run examples/tls_server.go + * + * in another shell, read the clock with modbus-cli as the 'user' role: + * $ go run cmd/modbus-cli.go --target tcp+tls://localhost:5802 --cert certs/user-client.cert.pem \ + * --key certs/user-client.key.pem --ca certs/server.cert.pem rh:uint32:0 + * + * attempting to set the clock as 'user' should fail with Illegal Function: + * $ go run cmd/modbus-cli.go --target tcp+tls://localhost:5802 --cert certs/user-client.cert.pem \ + * --key certs/user-client.key.pem --ca certs/server.cert.pem wr:uint32:0:1598692358 + * + * setting the clock as 'operator' should succeed: + * $ go run cmd/modbus-cli.go --target tcp+tls://localhost:5802 --cert certs/operator-client.cert.pem \ + * --key certs/operator-client.key.pem --ca certs/server.cert.pem wr:uint32:0:1598692358 + * + * reading the cock as 'operator' should also work: + * $ go run cmd/modbus-cli.go --target tcp+tls://localhost:5802 --cert certs/operator-client.cert.pem \ + * --key certs/operator-client.key.pem --ca certs/server.cert.pem rh:uint32:0 + */ + +func main() { + var err error + var eh *exampleHandler + var server *modbus.ModbusServer + var serverKeyPair tls.Certificate + var clientCertPool *x509.CertPool + var ticker *time.Ticker + + // create the handler object + eh = &exampleHandler{} + + // load the server certificate and its associated private key, which + // are used to authenticate the server to the client. + // note that a tls.Certificate object can contain both the cert and its key, + // which is the case here. + serverKeyPair, err = tls.LoadX509KeyPair( + "certs/server.cert.pem", "certs/server.key.pem") + if err != nil { + fmt.Printf("failed to load server key pair: %v\n", err) + os.Exit(1) + } + + // load TLS client authentication material, which could either be: + // - the CA (Certificate Authority) certificate(s) used to sign client certs, + // - the list of allowed client certs, if client certificates are self-signed or + // if client certificate pinning is required. + clientCertPool, err = modbus.LoadCertPool("certs/clients.cert.pem") + if err != nil { + fmt.Printf("failed to load CA/client certificates: %v\n", err) + os.Exit(1) + } + + // create the server object + server, err = modbus.NewServer(&modbus.ServerConfiguration{ + // listen on localhost port 5802 + URL: "tcp+tls://localhost:5802", + // accept 10 concurrent connections max. + MaxClients: 10, + // close idle connections after 1min of inactivity + Timeout: 60 * time.Second, + // use serverKeyPair as server certificate + server private key + TLSServerCert: &serverKeyPair, + // use the client cert/CA pool to verify client certificates + TLSClientCAs: clientCertPool, + }, eh) + if err != nil { + fmt.Printf("failed to create server: %v\n", err) + os.Exit(1) + } + + // start accepting client connections + // note that Start() returns as soon as the server is started + err = server.Start() + if err != nil { + fmt.Printf("failed to start server: %v\n", err) + os.Exit(1) + } + fmt.Println("server started") + + ticker = time.NewTicker(1 * time.Second) + for { + <-ticker.C + + // increment the clock every second. + // lock the handler object while updating the clock register to avoid + // concurrency issues as each client is served from a dedicated goroutine. + eh.lock.Lock() + eh.clock++ + eh.lock.Unlock() + } + + // never reached + + return +} + +// Example handler object, passed to the NewServer() constructor above. +type exampleHandler struct { + // this lock is used to avoid concurrency issues between goroutines, as + // handler methods are called from different goroutines + // (1 goroutine per client) + lock sync.RWMutex + + // unix timestamp register, incremented in the main() function above and exposed + // as a 32-bit holding register (2 consecutive 16-bit modbus registers). + clock uint32 +} + +// Holding register handler method. +// This method gets called whenever a valid modbus request asking for a holding register +// operation is received by the server. +func (eh *exampleHandler) HandleHoldingRegisters(req *modbus.HoldingRegistersRequest) (res []uint16, err error) { + var regAddr uint16 + + // require the "operator" role for write operations (i.e. set the clock). + if req.IsWrite && req.ClientRole != "operator" { + fmt.Printf("write access denied: client %s missing the 'operator' role (role: '%s')\n", + req.ClientAddr, req.ClientRole) + err = modbus.ErrIllegalFunction + return + } + + // since we're manipulating variables accessed from multiple goroutines, + // acquire a lock to avoid concurrency issues. + eh.lock.Lock() + // release the lock upon return + defer eh.lock.Unlock() + + // loop through `quantity` registers + for i := 0; i < int(req.Quantity); i++ { + // compute the target register address + regAddr = req.Addr + uint16(i) + + switch regAddr { + // expose the 16 most-significant bits of the clock in register #0 + case 0: + if req.IsWrite { + eh.clock = + ((uint32(req.Args[i]) << 16) & 0xffff0000 | + (eh.clock & 0x0000ffff)) + } + res = append(res, uint16((eh.clock >> 16) & 0x0000ffff)) + + // expose the 16 least-significant bits of the clock in register #1 + case 1: + if req.IsWrite { + eh.clock = + (uint32(req.Args[i]) & 0x0000ffff | + (eh.clock & 0xffff0000)) + } + res = append(res, uint16(eh.clock & 0x0000ffff)) + + // any other address is unknown + default: + err = modbus.ErrIllegalDataAddress + return + } + } + + return +} + +// input registers are not used by this server. +func (eh *exampleHandler) HandleInputRegisters(req *modbus.InputRegistersRequest) (res []uint16, err error) { + // this is the equivalent of saying + // "input registers are not supported by this device" + err = modbus.ErrIllegalFunction + return +} + +// coils are not used by this server. +func (eh *exampleHandler) HandleCoils(req *modbus.CoilsRequest) (res []bool, err error) { + // this is the equivalent of saying + // "coils are not supported by this device" + err = modbus.ErrIllegalFunction + return +} + +// discrete inputs are not used by this server. +func (eh *exampleHandler) HandleDiscreteInputs(req *modbus.DiscreteInputsRequest) (res []bool, err error) { + // this is the equivalent of saying + // "discrete inputs are not supported by this device" + err = modbus.ErrIllegalFunction + return +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..47ce4c1 --- /dev/null +++ b/go.mod @@ -0,0 +1,7 @@ +module git.whblueocean.cn/communication-protocol/modbus + +// module github.com/simonvetter/modbus + +go 1.16 + +require github.com/goburrow/serial v0.1.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6cb2ac4 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/goburrow/serial v0.1.0 h1:v2T1SQa/dlUqQiYIT8+Cu7YolfqAi3K96UmhwYyuSrA= +github.com/goburrow/serial v0.1.0/go.mod h1:sAiqG0nRVswsm1C97xsttiYCzSLBmUZ/VSlVLZJ8haA= diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..4aeb630 --- /dev/null +++ b/logger.go @@ -0,0 +1,81 @@ +package modbus + +import ( + "fmt" + "os" + "log" +) + +type logger struct { + prefix string + customLogger *log.Logger +} + +func newLogger(prefix string, customLogger *log.Logger) (l *logger) { + l = &logger{ + prefix: prefix, + customLogger: customLogger, + } + + return +} + +func (l *logger) Info(msg string) { + l.write(fmt.Sprintf("%s [info]: %s\n", l.prefix, msg)) + + return +} + +func (l *logger) Infof(format string, msg ...interface{}) { + l.write(fmt.Sprintf("%s [info]: %s\n", l.prefix, fmt.Sprintf(format, msg...))) + + return +} + +func (l *logger) Warning(msg string) { + l.write(fmt.Sprintf("%s [warn]: %s\n", l.prefix, msg)) + + return +} + +func (l *logger) Warningf(format string, msg ...interface{}) { + l.write(fmt.Sprintf("%s [warn]: %s\n", l.prefix, fmt.Sprintf(format, msg...))) + + return +} + +func (l *logger) Error(msg string) { + l.write(fmt.Sprintf("%s [error]: %s\n", l.prefix, msg)) + + return +} + +func (l *logger) Errorf(format string, msg ...interface{}) { + l.write(fmt.Sprintf("%s [error]: %s\n", l.prefix, fmt.Sprintf(format, msg...))) + + return +} + +func (l *logger) Fatal(msg string) { + l.Error(msg) + os.Exit(1) + + return +} + +func (l *logger) Fatalf(format string, msg ...interface{}) { + l.Errorf(format, msg...) + os.Exit(1) + + return +} + +func (l *logger) write(msg string) { + if l.customLogger == nil { + os.Stdout.WriteString(msg) + } else { + l.customLogger.Print(msg) + } + + return +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..82d7262 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,43 @@ +package modbus + +import ( + "bytes" + "log" + "testing" +) + +func TestClientCustomLogger(t *testing.T) { + var buf bytes.Buffer + var logger *log.Logger + + logger = log.New(&buf, "external-prefix: ", 0) + + _, _ = NewClient(&ClientConfiguration{ + Logger: logger, + URL: "sometype://sometarget", + }) + + if buf.String() != "external-prefix: modbus-client(sometarget) [error]: unsupported client type 'sometype'\n" { + t.Errorf("unexpected logger output '%s'", buf.String()) + } + + return +} + +func TestServerCustomLogger(t *testing.T) { + var buf bytes.Buffer + var logger *log.Logger + + logger = log.New(&buf, "external-prefix: ", 0) + + _, _ = NewServer(&ServerConfiguration{ + Logger: logger, + URL: "tcp://", + }, nil) + + if buf.String() != "external-prefix: modbus-server() [error]: missing host part in URL 'tcp://'\n" { + t.Errorf("unexpected logger output '%s'", buf.String()) + } + + return +} diff --git a/modbus.go b/modbus.go new file mode 100644 index 0000000..f527f77 --- /dev/null +++ b/modbus.go @@ -0,0 +1,132 @@ +package modbus + +import ( + "fmt" +) + +type pdu struct { + unitId uint8 + functionCode uint8 + payload []byte +} + +type Error string + +// Error implements the error interface. +func (me Error) Error() (s string) { + s = string(me) + return +} + +const ( + // coils + fcReadCoils uint8 = 0x01 + fcWriteSingleCoil uint8 = 0x05 + fcWriteMultipleCoils uint8 = 0x0f + + // discrete inputs + fcReadDiscreteInputs uint8 = 0x02 + + // 16-bit input/holding registers + fcReadHoldingRegisters uint8 = 0x03 + fcReadInputRegisters uint8 = 0x04 + fcWriteSingleRegister uint8 = 0x06 + fcWriteMultipleRegisters uint8 = 0x10 + fcMaskWriteRegister uint8 = 0x16 + fcReadWriteMultipleRegisters uint8 = 0x17 + fcReadFifoQueue uint8 = 0x18 + + // file access + fcReadFileRecord uint8 = 0x14 + fcWriteFileRecord uint8 = 0x15 + + // customize + fcCustomize uint8 = 0x29 + + // exception codes + exIllegalFunction uint8 = 0x01 + exIllegalDataAddress uint8 = 0x02 + exIllegalDataValue uint8 = 0x03 + exServerDeviceFailure uint8 = 0x04 + exAcknowledge uint8 = 0x05 + exServerDeviceBusy uint8 = 0x06 + exMemoryParityError uint8 = 0x08 + exGWPathUnavailable uint8 = 0x0a + exGWTargetFailedToRespond uint8 = 0x0b + + // errors + ErrConfigurationError Error = "configuration error" + ErrRequestTimedOut Error = "request timed out" + ErrIllegalFunction Error = "illegal function" + ErrIllegalDataAddress Error = "illegal data address" + ErrIllegalDataValue Error = "illegal data value" + ErrServerDeviceFailure Error = "server device failure" + ErrAcknowledge Error = "request acknowledged" + ErrServerDeviceBusy Error = "server device busy" + ErrMemoryParityError Error = "memory parity error" + ErrGWPathUnavailable Error = "gateway path unavailable" + ErrGWTargetFailedToRespond Error = "gateway target device failed to respond" + ErrBadCRC Error = "bad crc" + ErrShortFrame Error = "short frame" + ErrProtocolError Error = "protocol error" + ErrBadUnitId Error = "bad unit id" + ErrBadTransactionId Error = "bad transaction id" + ErrUnknownProtocolId Error = "unknown protocol identifier" + ErrUnexpectedParameters Error = "unexpected parameters" +) + +// mapExceptionCodeToError turns a modbus exception code into a higher level Error object. +func mapExceptionCodeToError(exceptionCode uint8) (err error) { + switch exceptionCode { + case exIllegalFunction: + err = ErrIllegalFunction + case exIllegalDataAddress: + err = ErrIllegalDataAddress + case exIllegalDataValue: + err = ErrIllegalDataValue + case exServerDeviceFailure: + err = ErrServerDeviceFailure + case exAcknowledge: + err = ErrAcknowledge + case exMemoryParityError: + err = ErrMemoryParityError + case exServerDeviceBusy: + err = ErrServerDeviceBusy + case exGWPathUnavailable: + err = ErrGWPathUnavailable + case exGWTargetFailedToRespond: + err = ErrGWTargetFailedToRespond + default: + err = fmt.Errorf("unknown exception code (%v)", exceptionCode) + } + + return +} + +// mapErrorToExceptionCode turns an Error object into a modbus exception code. +func mapErrorToExceptionCode(err error) (exceptionCode uint8) { + switch err { + case ErrIllegalFunction: + exceptionCode = exIllegalFunction + case ErrIllegalDataAddress: + exceptionCode = exIllegalDataAddress + case ErrIllegalDataValue: + exceptionCode = exIllegalDataValue + case ErrServerDeviceFailure: + exceptionCode = exServerDeviceFailure + case ErrAcknowledge: + exceptionCode = exAcknowledge + case ErrMemoryParityError: + exceptionCode = exMemoryParityError + case ErrServerDeviceBusy: + exceptionCode = exServerDeviceBusy + case ErrGWPathUnavailable: + exceptionCode = exGWPathUnavailable + case ErrGWTargetFailedToRespond: + exceptionCode = exGWTargetFailedToRespond + default: + exceptionCode = exServerDeviceFailure + } + + return +} diff --git a/rtu_transport.go b/rtu_transport.go new file mode 100644 index 0000000..fe1190f --- /dev/null +++ b/rtu_transport.go @@ -0,0 +1,269 @@ +package modbus + +import ( + "fmt" + "io" + "log" + "time" +) + +const ( + maxRTUFrameLength int = 256 +) + +type rtuTransport struct { + logger *logger + link rtuLink + timeout time.Duration + lastActivity time.Time + t35 time.Duration + t1 time.Duration +} + +type rtuLink interface { + Close() (error) + Read([]byte) (int, error) + Write([]byte) (int, error) + SetDeadline(time.Time) (error) +} + +// Returns a new RTU transport. +func newRTUTransport(link rtuLink, addr string, speed uint, timeout time.Duration, customLogger *log.Logger) (rt *rtuTransport) { + rt = &rtuTransport{ + logger: newLogger(fmt.Sprintf("rtu-transport(%s)", addr), customLogger), + link: link, + timeout: timeout, + t1: serialCharTime(speed), + } + + if speed >= 19200 { + // for baud rates equal to or greater than 19200 bauds, a fixed value of + // 1750 uS is specified for t3.5. + rt.t35 = 1750 * time.Microsecond + } else { + // for lower baud rates, the inter-frame delay should be 3.5 character times + rt.t35 = (serialCharTime(speed) * 35) / 10 + } + + return +} + +// Closes the rtu link. +func (rt *rtuTransport) Close() (err error) { + err = rt.link.Close() + + return +} + +// Runs a request across the rtu link and returns a response. +func (rt *rtuTransport) ExecuteRequest(req *pdu) (res *pdu, err error) { + var ts time.Time + var t time.Duration + var n int + + // set an i/o deadline on the link + err = rt.link.SetDeadline(time.Now().Add(rt.timeout)) + if err != nil { + return + } + + // if the line was active less than 3.5 char times ago, + // let t3.5 expire before transmitting + t = time.Since(rt.lastActivity.Add(rt.t35)) + if t < 0 { + time.Sleep(t * (-1)) + } + + ts = time.Now() + + // build an RTU ADU out of the request object and + // send the final ADU+CRC on the wire + n, err = rt.link.Write(rt.assembleRTUFrame(req)) + if err != nil { + return + } + + // estimate how long the serial line was busy for. + // note that on most platforms, Write() will be buffered and return + // immediately rather than block until the buffer is drained + rt.lastActivity = ts.Add(time.Duration(n) * rt.t1) + + // observe inter-frame delays + time.Sleep(rt.lastActivity.Add(rt.t35).Sub(time.Now())) + + // read the response back from the wire + res, err = rt.readRTUFrame() + + if err == ErrBadCRC || err == ErrProtocolError || err == ErrShortFrame { + // wait for and flush any data coming off the link to allow + // devices to re-sync + time.Sleep(time.Duration(maxRTUFrameLength) * rt.t1) + discard(rt.link) + } + + // mark the time if we heard anything back + if err != ErrRequestTimedOut { + rt.lastActivity = time.Now() + } + + return +} + +// Reads a request from the rtu link. +func (rt *rtuTransport) ReadRequest() (req *pdu, err error) { + // reading requests from RTU links is currently unsupported + err = fmt.Errorf("unimplemented") + + return +} + +// Writes a response to the rtu link. +func (rt *rtuTransport) WriteResponse(res *pdu) (err error) { + var n int + + // build an RTU ADU out of the request object and + // send the final ADU+CRC on the wire + n, err = rt.link.Write(rt.assembleRTUFrame(res)) + if err != nil { + return + } + + rt.lastActivity = time.Now().Add(rt.t1 * time.Duration(n)) + + return +} + +// Waits for, reads and decodes a frame from the rtu link. +func (rt *rtuTransport) readRTUFrame() (res *pdu, err error) { + var rxbuf []byte + var byteCount int + var bytesNeeded int + var crc crc + + rxbuf = make([]byte, maxRTUFrameLength) + + // read the serial ADU header: unit id (1 byte), function code (1 byte) and + // PDU length/exception code (1 byte) + byteCount, err = io.ReadFull(rt.link, rxbuf[0:3]) + if (byteCount > 0 || err == nil) && byteCount != 3 { + err = ErrShortFrame + return + } + if err != nil && err != io.ErrUnexpectedEOF { + return + } + + // figure out how many further bytes to read + bytesNeeded, err = expectedResponseLenth(uint8(rxbuf[1]), uint8(rxbuf[2])) + if err != nil { + return + } + + // we need to read 2 additional bytes of CRC after the payload + bytesNeeded += 2 + + // never read more than the max allowed frame length + if byteCount + bytesNeeded > maxRTUFrameLength { + err = ErrProtocolError + return + } + + byteCount, err = io.ReadFull(rt.link, rxbuf[3:3 + bytesNeeded]) + if err != nil && err != io.ErrUnexpectedEOF { + return + } + if byteCount != bytesNeeded { + rt.logger.Warningf("expected %v bytes, received %v", bytesNeeded, byteCount) + err = ErrShortFrame + return + } + + // compute the CRC on the entire frame, excluding the CRC + crc.init() + crc.add(rxbuf[0:3 + bytesNeeded - 2]) + + // compare CRC values + if !crc.isEqual(rxbuf[3 + bytesNeeded - 2], rxbuf[3 + bytesNeeded - 1]) { + err = ErrBadCRC + return + } + + res = &pdu{ + unitId: rxbuf[0], + functionCode: rxbuf[1], + // pass the byte count + trailing data as payload, withtout the CRC + payload: rxbuf[2:3 + bytesNeeded - 2], + } + + return +} + +// Turns a PDU object into bytes. +func (rt *rtuTransport) assembleRTUFrame(p *pdu) (adu []byte) { + var crc crc + + adu = append(adu, p.unitId) + adu = append(adu, p.functionCode) + adu = append(adu, p.payload...) + + // run the ADU through the CRC generator + crc.init() + crc.add(adu) + + // append the CRC to the ADU + adu = append(adu, crc.value()...) + + return +} + +// Computes the expected length of a modbus RTU response. +func expectedResponseLenth(responseCode uint8, responseLength uint8) (byteCount int, err error) { + switch responseCode { + case fcReadHoldingRegisters, + fcReadInputRegisters, + fcReadCoils, + fcReadDiscreteInputs: byteCount = int(responseLength) + case fcWriteSingleRegister, + fcWriteMultipleRegisters, + fcWriteSingleCoil, + fcWriteMultipleCoils: byteCount = 3 + case fcMaskWriteRegister: byteCount = 5 + case fcReadHoldingRegisters | 0x80, + fcReadInputRegisters | 0x80, + fcReadCoils | 0x80, + fcReadDiscreteInputs | 0x80, + fcWriteSingleRegister | 0x80, + fcWriteMultipleRegisters | 0x80, + fcWriteSingleCoil | 0x80, + fcWriteMultipleCoils | 0x80, + fcMaskWriteRegister | 0x80: byteCount = 0 + default: err = ErrProtocolError + } + + return +} + +// Discards the contents of the link's rx buffer, eating up to 1kB of data. +// Note that on a serial line, this call may block for up to serialConf.Timeout +// i.e. 10ms. +func discard(link rtuLink) { + var rxbuf = make([]byte, 1024) + + link.SetDeadline(time.Now().Add(500 * time.Microsecond)) + io.ReadFull(link, rxbuf) + + return +} + +// Returns how long it takes to send 1 byte on a serial line at the +// specified baud rate. +func serialCharTime(rate_bps uint) (ct time.Duration) { + // note: an RTU byte on the wire is: + // - 1 start bit, + // - 8 data bits, + // - 1 parity or stop bit + // - 1 stop bit + ct = (11) * time.Second / time.Duration(rate_bps) + + return +} diff --git a/rtu_transport_test.go b/rtu_transport_test.go new file mode 100644 index 0000000..f4f950c --- /dev/null +++ b/rtu_transport_test.go @@ -0,0 +1,189 @@ +package modbus + +import ( + "testing" + "io" + "net" + "time" +) + +func TestAssembleRTUFrame(t *testing.T) { + var rt *rtuTransport + var frame []byte + + rt = &rtuTransport{} + + frame = rt.assembleRTUFrame(&pdu{ + unitId: 0x33, + functionCode: 0x11, + payload: []byte{0x22, 0x33, 0x44, 0x55}, + }) + // expect 1 byte of unit id, 1 byte of function code, 4 bytes of payload and + // 2 bytes of CRC + if len(frame) != 8 { + t.Errorf("expected 8 bytes, got %v", len(frame)) + } + for i, b := range []byte{ + 0x33, 0x11, // unit id and function code + 0x22, 0x33, // payload + 0x44, 0x55, // payload + 0xf0, 0x93, // CRC + } { + if frame[i] != b { + t.Errorf("expected 0x%02x at position %v, got 0x%02x", b, i, frame[i]) + } + } + + frame = rt.assembleRTUFrame(&pdu{ + unitId: 0x31, + functionCode: 0x06, + payload: []byte{0x12, 0x34}, + }) + // expect 1 byte of unit if, 1 byte of function code, 2 bytes of payload and + // 2 bytes of CRC + if len(frame) != 6 { + t.Errorf("expected 6 bytes, got %v", len(frame)) + } + for i, b := range []byte{ + 0x31, 0x06, // unit id and function code + 0x12, 0x34, // payload + 0xe3, 0xae, // CRC + } { + if frame[i] != b { + t.Errorf("expected 0x%02x at position %v, got 0x%02x", b, i, frame[i]) + } + } + + return +} + + +func TestRTUTransportReadRTUFrame(t *testing.T) { + var rt *rtuTransport + var p1, p2 net.Conn + var txchan chan []byte + var err error + var res *pdu + + txchan = make(chan []byte, 2) + p1, p2 = net.Pipe() + go feedTestPipe(t, txchan, p1) + + + rt = newRTUTransport(p2, "", 9600, 10 * time.Millisecond, nil) + + // read a valid response (illegal data address) + txchan <- []byte{ + 0x31, 0x82, // unit id and response code + 0x02, // exception code + 0xc1, 0x6e, // CRC + } + res, err = rt.readRTUFrame() + if err != nil { + t.Errorf("readRTUFrame() should have succeeded, got %v", err) + } + if res.unitId != 0x31 { + t.Errorf("expected 0x31 as unit id, got 0x%02x", res.unitId) + } + if res.functionCode != 0x82 { + t.Errorf("expected 0x82 as function code, got 0x%02x", res.functionCode) + } + if len(res.payload) != 1 { + t.Errorf("expected a length of 1, got %v", len(res.payload)) + } + if res.payload[0] != 0x02 { + t.Errorf("expected {0x02} as payload, got {0x%02x}", + res.payload[0]) + } + + // read a frame with a bad crc + txchan <- []byte{ + 0x30, 0x82, // unit id and response code + 0x12, // exception code + 0xc0, 0xa2, // CRC + } + res, err = rt.readRTUFrame() + if err != ErrBadCRC { + t.Errorf("readRTUFrame() should have returned ErrBadCrc, got %v", err) + } + + // read a longer, valid response + txchan <- []byte{ + 0x31, 0x03, // unit id and response code + 0x04, // length + 0x11, 0x22, // register #1 + 0x33, 0x44, // register #2 + 0x7b, 0xc5, // CRC + } + res, err = rt.readRTUFrame() + if err != nil { + t.Errorf("readRTUFrame() should have succeeded, got %v", err) + } + if res.unitId != 0x31 { + t.Errorf("expected 0x31 as unit id, got 0x%02x", res.unitId) + } + if res.functionCode != 0x03 { + t.Errorf("expected 0x03 as function code, got 0x%02x", res.functionCode) + } + if len(res.payload) != 5 { + t.Errorf("expected a length of 5, got %v", len(res.payload)) + } + for i, b := range []byte{ + 0x04, + 0x11, 0x22, + 0x33, 0x44, + } { + if res.payload[i] != b { + t.Errorf("expected 0x%02x at position %v, got 0x%02x", + b, i, res.payload[i]) + } + } + + p1.Close() + p2.Close() + + return +} + +func feedTestPipe(t *testing.T, in chan []byte, out io.WriteCloser) { + var err error + var txbuf []byte + + for { + // grab a slice of bytes from the channel + txbuf = <-in + + // write this slice to the pipe + _, err = out.Write(txbuf) + if err != nil { + t.Errorf("failed to write to test pipe: %v", err) + return + } + } + + return +} + +func TestModbusRTUSerialCharTime(t *testing.T) { + var d time.Duration + + d = serialCharTime(38400) + // expect 11 bits at 38400bps: 11 * (1/38400) = 286.458uS + if d != time.Duration(286458) * time.Nanosecond { + t.Errorf("unexpected serial char duration: %v", d) + } + + d = serialCharTime(19200) + // expect 11 bits at 19200bps: 11 * (1/19200) = 572.916uS + if d != time.Duration(572916) * time.Nanosecond { + t.Errorf("unexpected serial char duration: %v", d) + } + + d = serialCharTime(9600) + // expect 11 bits at 9600bps: 11 * (1/9600) = 1.145833ms + if d != time.Duration(1145833) * time.Nanosecond { + t.Errorf("unexpected serial char duration: %v", d) + } + + return +} diff --git a/serial.go b/serial.go new file mode 100644 index 0000000..075b542 --- /dev/null +++ b/serial.go @@ -0,0 +1,103 @@ +package modbus + +import ( + "time" + + "github.com/goburrow/serial" +) + +// serialPortWrapper wraps a serial.Port (i.e. physical port) to +// 1) satisfy the rtuLink interface and +// 2) add Read() deadline/timeout support. +type serialPortWrapper struct { + conf *serialPortConfig + port serial.Port + deadline time.Time +} + +type serialPortConfig struct { + Device string + Speed uint + DataBits uint + Parity uint + StopBits uint +} + +func newSerialPortWrapper(conf *serialPortConfig) (spw *serialPortWrapper) { + spw = &serialPortWrapper{ + conf: conf, + } + + return +} + +func (spw *serialPortWrapper) Open() (err error) { + var parity string + + switch spw.conf.Parity { + case PARITY_NONE: parity = "N" + case PARITY_EVEN: parity = "E" + case PARITY_ODD: parity = "O" + } + + spw.port, err = serial.Open(&serial.Config{ + Address: spw.conf.Device, + BaudRate: int(spw.conf.Speed), + DataBits: int(spw.conf.DataBits), + Parity: parity, + StopBits: int(spw.conf.StopBits), + Timeout: 10 * time.Millisecond, + }) + + return +} + +// Closes the serial port. +func (spw *serialPortWrapper) Close() (err error) { + err = spw.port.Close() + + return +} + +// Reads bytes from the underlying serial port. +// If Read() is called after the deadline, a timeout error is returned without +// attempting to read from the serial port. +// If Read() is called before the deadline, a read attempt to the serial port +// is made. At this point, one of two things can happen: +// - the serial port's receive buffer has one or more bytes and port.Read() +// returns immediately (partial or full read), +// - the serial port's receive buffer is empty: port.Read() blocks for +// up to 10ms and returns serial.ErrTimeout. The serial timeout error is +// masked and Read() returns with no data. +// As the higher-level methods use io.ReadFull(), Read() will be called +// as many times as necessary until either enough bytes have been read or an +// error is returned (ErrRequestTimedOut or any other i/o error). +func (spw *serialPortWrapper) Read(rxbuf []byte) (cnt int, err error) { + // return a timeout error if the deadline has passed + if time.Now().After(spw.deadline) { + err = ErrRequestTimedOut + return + } + + cnt, err = spw.port.Read(rxbuf) + // mask serial.ErrTimeout errors from the serial port + if err != nil && err == serial.ErrTimeout { + err = nil + } + + return +} + +// Sends the bytes over the wire. +func (spw *serialPortWrapper) Write(txbuf []byte) (cnt int, err error) { + cnt, err = spw.port.Write(txbuf) + + return +} + +// Saves the i/o deadline (only used by Read). +func (spw *serialPortWrapper) SetDeadline(deadline time.Time) (err error) { + spw.deadline = deadline + + return +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..0124890 --- /dev/null +++ b/server.go @@ -0,0 +1,901 @@ +package modbus + +import ( + "crypto/tls" + "crypto/x509" + "encoding/asn1" + "errors" + "fmt" + "log" + "net" + "strings" + "sync" + "time" +) + +// Modbus Role PEM OID (see R-21 of the MBAPS spec) +var modbusRoleOID asn1.ObjectIdentifier = asn1.ObjectIdentifier{ + 1, 3, 6, 1, 4, 1, 50316, 802, 1, +} + +// Server configuration object. +type ServerConfiguration struct { + // URL defines where to listen at e.g. tcp://[::]:502 + URL string + // Timeout sets the idle session timeout (client connections will + // be closed if idle for this long) + Timeout time.Duration + // MaxClients sets the maximum number of concurrent client connections + MaxClients uint + // TLSServerCert sets the server-side TLS key pair (tcp+tls only) + TLSServerCert *tls.Certificate + // TLSClientCAs sets the list of CA certificates used to authenticate + // client connections (tcp+tls only). Leaf (i.e. client) certificates can + // also be used in case of self-signed certs, or if cert pinning is required. + TLSClientCAs *x509.CertPool + // Logger provides a custom sink for log messages. + // If nil, messages will be written to stdout. + Logger *log.Logger +} + +// Request object passed to the coil handler. +type CoilsRequest struct { + ClientAddr string // the source (client) IP address + ClientRole string // the client role as encoded in the client certificate (tcp+tls only) + UnitId uint8 // the requested unit id (slave id) + Addr uint16 // the base coil address requested + Quantity uint16 // the number of consecutive coils covered by this request + // (first address: Addr, last address: Addr + Quantity - 1) + IsWrite bool // true if the request is a write, false if a read + Args []bool // a slice of bool values of the coils to be set, ordered + // from Addr to Addr + Quantity - 1 (for writes only) +} + +// Request object passed to the discrete input handler. +type DiscreteInputsRequest struct { + ClientAddr string // the source (client) IP address + ClientRole string // the client role as encoded in the client certificate (tcp+tls only) + UnitId uint8 // the requested unit id (slave id) + Addr uint16 // the base discrete input address requested + Quantity uint16 // the number of consecutive discrete inputs covered by this request +} + +// Request object passed to the holding register handler. +type HoldingRegistersRequest struct { + ClientAddr string // the source (client) IP address + ClientRole string // the client role as encoded in the client certificate (tcp+tls only) + UnitId uint8 // the requested unit id (slave id) + Addr uint16 // the base register address requested + Quantity uint16 // the number of consecutive registers covered by this request + IsWrite bool // true if the request is a write, false if a read + Args []uint16 // a slice of register values to be set, ordered from + // Addr to Addr + Quantity - 1 (for writes only) +} + +// Request object passed to the input register handler. +type InputRegistersRequest struct { + ClientAddr string // the source (client) IP address + ClientRole string // the client role as encoded in the client certificate (tcp+tls only) + UnitId uint8 // the requested unit id (slave id) + Addr uint16 // the base register address requested + Quantity uint16 // the number of consecutive registers covered by this request +} + +// The RequestHandler interface should be implemented by the handler +// object passed to NewServer (see reqHandler in NewServer()). +// After decoding and validating an incoming request, the server will +// invoke the appropriate handler function, depending on the function code +// of the request. +type RequestHandler interface { + // HandleCoils handles the read coils (0x01), write single coil (0x05) + // and write multiple coils (0x0f) function codes. + // A CoilsRequest object is passed to the handler (see above). + // + // Expected return values: + // - res: a slice of bools containing the coil values to be sent to back + // to the client (only sent for reads), + // - err: either nil if no error occurred, a modbus error (see + // mapErrorToExceptionCode() in modbus.go for a complete list), + // or any other error. + // If nil, a positive modbus response is sent back to the client + // along with the returned data. + // If non-nil, a negative modbus response is sent back, with the + // exception code set depending on the error + // (again, see mapErrorToExceptionCode()). + HandleCoils (req *CoilsRequest) (res []bool, err error) + + // HandleDiscreteInputs handles the read discrete inputs (0x02) function code. + // A DiscreteInputsRequest oibject is passed to the handler (see above). + // + // Expected return values: + // - res: a slice of bools containing the discrete input values to be + // sent back to the client, + // - err: either nil if no error occurred, a modbus error (see + // mapErrorToExceptionCode() in modbus.go for a complete list), + // or any other error. + HandleDiscreteInputs (req *DiscreteInputsRequest) (res []bool, err error) + + // HandleHoldingRegisters handles the read holding registers (0x03), + // write single register (0x06) and write multiple registers (0x10). + // A HoldingRegistersRequest object is passed to the handler (see above). + // + // Expected return values: + // - res: a slice of uint16 containing the register values to be sent + // to back to the client (only sent for reads), + // - err: either nil if no error occurred, a modbus error (see + // mapErrorToExceptionCode() in modbus.go for a complete list), + // or any other error. + HandleHoldingRegisters (req *HoldingRegistersRequest) (res []uint16, err error) + + // HandleInputRegisters handles the read input registers (0x04) function code. + // An InputRegistersRequest object is passed to the handler (see above). + // + // Expected return values: + // - res: a slice of uint16 containing the register values to be sent + // back to the client, + // - err: either nil if no error occurred, a modbus error (see + // mapErrorToExceptionCode() in modbus.go for a complete list), + // or any other error. + HandleInputRegisters (req *InputRegistersRequest) (res []uint16, err error) +} + +// Modbus server object. +type ModbusServer struct { + conf ServerConfiguration + logger *logger + lock sync.Mutex + started bool + handler RequestHandler + tcpListener net.Listener + tcpClients []net.Conn + transportType transportType +} + +// Returns a new modbus server. +// reqHandler should be a user-provided handler object satisfying the RequestHandler +// interface. +func NewServer(conf *ServerConfiguration, reqHandler RequestHandler) ( + ms *ModbusServer, err error) { + var serverType string + var splitURL []string + + ms = &ModbusServer{ + conf: *conf, + handler: reqHandler, + } + + splitURL = strings.SplitN(ms.conf.URL, "://", 2) + if len(splitURL) == 2 { + serverType = splitURL[0] + ms.conf.URL = splitURL[1] + } + + ms.logger = newLogger( + fmt.Sprintf("modbus-server(%s)", ms.conf.URL), ms.conf.Logger) + + if ms.conf.URL == "" { + ms.logger.Errorf("missing host part in URL '%s'", conf.URL) + err = ErrConfigurationError + return + } + + switch serverType { + case "tcp": + if ms.conf.Timeout == 0 { + ms.conf.Timeout = 120 * time.Second + } + + if ms.conf.MaxClients == 0 { + ms.conf.MaxClients = 10 + } + + ms.transportType = modbusTCP + + case "tcp+tls": + if ms.conf.Timeout == 0 { + ms.conf.Timeout = 120 * time.Second + } + + if ms.conf.MaxClients == 0 { + ms.conf.MaxClients = 10 + } + + // expect a server-side certificate + if ms.conf.TLSServerCert == nil { + ms.logger.Errorf("missing server certificate") + err = ErrConfigurationError + return + } + + // expect a CertPool object containing at least 1 CA or + // leaf certificate to validate client-side certificates + if ms.conf.TLSClientCAs == nil { + ms.logger.Errorf("missing CA/client certificates") + err = ErrConfigurationError + return + } + + ms.transportType = modbusTCPOverTLS + + default: + err = ErrConfigurationError + return + } + + return +} + +// Starts accepting client connections. +func (ms *ModbusServer) Start() (err error) { + ms.lock.Lock() + defer ms.lock.Unlock() + + if ms.started { + return + } + + switch ms.transportType { + case modbusTCP, modbusTCPOverTLS: + // bind to a TCP socket + ms.tcpListener, err = net.Listen("tcp", ms.conf.URL) + if err != nil { + return + } + + // accept client connections in a goroutine + go ms.acceptTCPClients() + + default: + err = ErrConfigurationError + return + } + + ms.started = true + + return +} + +// Stops accepting new client connections and closes any active session. +func (ms *ModbusServer) Stop() (err error) { + ms.lock.Lock() + defer ms.lock.Unlock() + + if !ms.started { + return + } + + ms.started = false + + if ms.transportType == modbusTCP || ms.transportType == modbusTCPOverTLS { + // close the server socket if we're listening over TCP + err = ms.tcpListener.Close() + + // close all active TCP clients + for _, sock := range ms.tcpClients{ + sock.Close() + } + } + + return +} + +// Accepts new client connections if the configured connection limit allows it. +// Each connection is served from a dedicated goroutine to allow for concurrent +// connections. +func (ms *ModbusServer) acceptTCPClients() { + var sock net.Conn + var err error + var accepted bool + + for { + sock, err = ms.tcpListener.Accept() + if err != nil { + // if the server socket has just been closed, return here as + // this goroutine isn't going to see any new client connection + if errors.Is(err, net.ErrClosed) { + return + } + ms.logger.Warningf("failed to accept client connection: %v", err) + continue + } + + ms.lock.Lock() + // apply a connection limit + if ms.started && uint(len(ms.tcpClients)) < ms.conf.MaxClients { + accepted = true + // add the new client connection to the pool + ms.tcpClients = append(ms.tcpClients, sock) + } else { + accepted = false + } + ms.lock.Unlock() + + if accepted { + // spin a client handler goroutine to serve the new client + go ms.handleTCPClient(sock) + } else { + ms.logger.Warningf("max. number of concurrent connections " + + "reached, rejecting %v", sock.RemoteAddr()) + // discard the connection + sock.Close() + } + } + + // never reached + return +} + +// Handles a TCP client connection. +// Once handleTransport() returns (i.e. the connection has either closed, timed +// out, or an unrecoverable error happened), the TCP socket is closed and removed +// from the list of active client connections. +func (ms *ModbusServer) handleTCPClient(sock net.Conn) { + var err error + var clientRole string + var tlsSock net.Conn + + switch ms.transportType { + case modbusTCP: + // serve modbus requests over the raw TCP connection + ms.handleTransport( + newTCPTransport(sock, ms.conf.Timeout, ms.conf.Logger), + sock.RemoteAddr().String(), "") + + case modbusTCPOverTLS: + // start TLS negotiation over the raw TCP connection + tlsSock, clientRole, err = ms.startTLS(sock) + if err != nil { + ms.logger.Warningf("TLS handshake with %s failed: %v", + sock.RemoteAddr().String(), err) + } else { + // serve modbus requests over the TLS tunnel + ms.handleTransport( + newTCPTransport(tlsSock, ms.conf.Timeout, ms.conf.Logger), + sock.RemoteAddr().String(), clientRole) + } + + default: + ms.logger.Errorf("unimplemented transport type %v", ms.transportType) + } + + // once done, remove our connection from the list of active client conns + ms.lock.Lock() + for i := range ms.tcpClients { + if ms.tcpClients[i] == sock { + ms.tcpClients[i] = ms.tcpClients[len(ms.tcpClients)-1] + ms.tcpClients = ms.tcpClients[:len(ms.tcpClients)-1] + break + } + } + ms.lock.Unlock() + + // close the connection + sock.Close() + + return +} + +// For each request read from the transport, performs decoding and validation, +// calls the user-provided handler, then encodes and writes the response +// to the transport. +func (ms *ModbusServer) handleTransport(t transport, clientAddr string, clientRole string) { + var req *pdu + var res *pdu + var err error + var addr uint16 + var quantity uint16 + + for { + req, err = t.ReadRequest() + if err != nil { + return + } + + switch req.functionCode { + case fcReadCoils, fcReadDiscreteInputs: + var coils []bool + var resCount int + + if len(req.payload) != 4 { + err = ErrProtocolError + break + } + + // decode address and quantity fields + addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2]) + quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4]) + + // ensure the reply never exceeds the maximum PDU length and we + // never read past 0xffff + if quantity > 2000 || quantity == 0 { + err = ErrProtocolError + break + } + if uint32(addr) + uint32(quantity) - 1 > 0xffff { + err = ErrIllegalDataAddress + break + } + + // invoke the appropriate handler + if req.functionCode == fcReadCoils { + coils, err = ms.handler.HandleCoils(&CoilsRequest{ + ClientAddr: clientAddr, + ClientRole: clientRole, + UnitId: req.unitId, + Addr: addr, + Quantity: quantity, + IsWrite: false, + Args: nil, + }) + } else { + coils, err = ms.handler.HandleDiscreteInputs( + &DiscreteInputsRequest{ + ClientAddr: clientAddr, + ClientRole: clientRole, + UnitId: req.unitId, + Addr: addr, + Quantity: quantity, + }) + } + resCount = len(coils) + + // make sure the handler returned the expected number of items + if err == nil && resCount != int(quantity) { + ms.logger.Errorf("handler returned %v bools, " + + "expected %v", resCount, quantity) + err = ErrServerDeviceFailure + break + } + + if err != nil { + break + } + + // assemble a response PDU + res = &pdu{ + unitId: req.unitId, + functionCode: req.functionCode, + payload: []byte{0}, + } + + // byte count (1 byte for 8 coils) + res.payload[0] = uint8(resCount / 8) + if resCount % 8 != 0 { + res.payload[0]++ + } + + // coil values + res.payload = append(res.payload, encodeBools(coils)...) + + case fcWriteSingleCoil: + if len(req.payload) != 4 { + err = ErrProtocolError + break + } + + // decode the address field + addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2]) + + // validate the value field (should be either 0xff00 or 0x0000) + if ((req.payload[2] != 0xff && req.payload[2] != 0x00) || + req.payload[3] != 0x00) { + err = ErrProtocolError + break + } + + // invoke the coil handler + _, err = ms.handler.HandleCoils(&CoilsRequest{ + ClientAddr: clientAddr, + ClientRole: clientRole, + UnitId: req.unitId, + Addr: addr, + Quantity: 1, // request for a single coil + IsWrite: true, // this is a write request + Args: []bool{(req.payload[2] == 0xff)}, + }) + + if err != nil { + break + } + + // assemble a response PDU + res = &pdu{ + unitId: req.unitId, + functionCode: req.functionCode, + } + + // echo the address and value in the response + res.payload = append(res.payload, + uint16ToBytes(BIG_ENDIAN, addr)...) + res.payload = append(res.payload, + req.payload[2], req.payload[3]) + + case fcWriteMultipleCoils: + var expectedLen int + + if len(req.payload) < 6 { + err = ErrProtocolError + break + } + + // decode address and quantity fields + addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2]) + quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4]) + + // ensure the reply never exceeds the maximum PDU length and we + // never read past 0xffff + if quantity > 0x7b0 || quantity == 0 { + err = ErrProtocolError + break + } + if uint32(addr) + uint32(quantity) - 1 > 0xffff { + err = ErrIllegalDataAddress + break + } + + // validate the byte count field (1 byte for 8 coils) + expectedLen = int(quantity) / 8 + if quantity % 8 != 0 { + expectedLen++ + } + + if req.payload[4] != uint8(expectedLen) { + err = ErrProtocolError + break + } + + // make sure we have enough bytes + if len(req.payload) - 5 != expectedLen { + err = ErrProtocolError + break + } + + // invoke the coil handler + _, err = ms.handler.HandleCoils(&CoilsRequest{ + ClientAddr: clientAddr, + ClientRole: clientRole, + UnitId: req.unitId, + Addr: addr, + Quantity: quantity, + IsWrite: true, // this is a write request + Args: decodeBools(quantity, req.payload[5:]), + }) + + if err != nil { + break + } + + // assemble a response PDU + res = &pdu{ + unitId: req.unitId, + functionCode: req.functionCode, + } + + // echo the address and quantity in the response + res.payload = append(res.payload, + uint16ToBytes(BIG_ENDIAN, addr)...) + res.payload = append(res.payload, + uint16ToBytes(BIG_ENDIAN, quantity)...) + + case fcReadHoldingRegisters, fcReadInputRegisters: + var regs []uint16 + var resCount int + + if len(req.payload) != 4 { + err = ErrProtocolError + break + } + + // decode address and quantity fields + addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2]) + quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4]) + + // ensure the reply never exceeds the maximum PDU length and we + // never read past 0xffff + if quantity > 0x007d || quantity == 0 { + err = ErrProtocolError + break + } + if uint32(addr) + uint32(quantity) - 1 > 0xffff { + err = ErrIllegalDataAddress + break + } + + // invoke the appropriate handler + if req.functionCode == fcReadHoldingRegisters { + regs, err = ms.handler.HandleHoldingRegisters( + &HoldingRegistersRequest{ + ClientAddr: clientAddr, + ClientRole: clientRole, + UnitId: req.unitId, + Addr: addr, + Quantity: quantity, + IsWrite: false, + Args: nil, + }) + } else { + regs, err = ms.handler.HandleInputRegisters( + &InputRegistersRequest{ + ClientAddr: clientAddr, + ClientRole: clientRole, + UnitId: req.unitId, + Addr: addr, + Quantity: quantity, + }) + } + resCount = len(regs) + + // make sure the handler returned the expected number of items + if err == nil && resCount != int(quantity) { + ms.logger.Errorf("handler returned %v 16-bit values, " + + "expected %v", resCount, quantity) + err = ErrServerDeviceFailure + break + } + + if err != nil { + break + } + + // assemble a response PDU + res = &pdu{ + unitId: req.unitId, + functionCode: req.functionCode, + payload: []byte{0}, + } + + // byte count (2 bytes per register) + res.payload[0] = uint8(resCount * 2) + + // register values + res.payload = append(res.payload, + uint16sToBytes(BIG_ENDIAN, regs)...) + + case fcWriteSingleRegister: + var value uint16 + + if len(req.payload) != 4 { + err = ErrProtocolError + break + } + + // decode address and value fields + addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2]) + value = bytesToUint16(BIG_ENDIAN, req.payload[2:4]) + + // invoke the handler + _, err = ms.handler.HandleHoldingRegisters( + &HoldingRegistersRequest{ + ClientAddr: clientAddr, + ClientRole: clientRole, + UnitId: req.unitId, + Addr: addr, + Quantity: 1, // request for a single register + IsWrite: true, // request is a write + Args: []uint16{value}, + }) + + if err != nil { + break + } + + // assemble a response PDU + res = &pdu{ + unitId: req.unitId, + functionCode: req.functionCode, + } + + // echo the address and value in the response + res.payload = append(res.payload, + uint16ToBytes(BIG_ENDIAN, addr)...) + res.payload = append(res.payload, + uint16ToBytes(BIG_ENDIAN, value)...) + + case fcWriteMultipleRegisters: + var expectedLen int + + if len(req.payload) < 6 { + err = ErrProtocolError + break + } + + // decode address and quantity fields + addr = bytesToUint16(BIG_ENDIAN, req.payload[0:2]) + quantity = bytesToUint16(BIG_ENDIAN, req.payload[2:4]) + + // ensure the reply never exceeds the maximum PDU length and we + // never read past 0xffff + if quantity > 0x007b || quantity == 0 { + err = ErrProtocolError + break + } + if uint32(addr) + uint32(quantity) - 1 > 0xffff { + err = ErrIllegalDataAddress + break + } + + // validate the byte count field (2 bytes per register) + expectedLen = int(quantity) * 2 + + if req.payload[4] != uint8(expectedLen) { + err = ErrProtocolError + break + } + + // make sure we have enough bytes + if len(req.payload) - 5 != expectedLen { + err = ErrProtocolError + break + } + + // invoke the holding register handler + _, err = ms.handler.HandleHoldingRegisters( + &HoldingRegistersRequest{ + ClientAddr: clientAddr, + ClientRole: clientRole, + UnitId: req.unitId, + Addr: addr, + Quantity: quantity, + IsWrite: true, // this is a write request + Args: bytesToUint16s(BIG_ENDIAN, req.payload[5:]), + }) + if err != nil { + break + } + + // assemble a response PDU + res = &pdu{ + unitId: req.unitId, + functionCode: req.functionCode, + } + + // echo the address and quantity in the response + res.payload = append(res.payload, + uint16ToBytes(BIG_ENDIAN, addr)...) + res.payload = append(res.payload, + uint16ToBytes(BIG_ENDIAN, quantity)...) + + default: + res = &pdu{ + // reply with the request target unit ID + unitId: req.unitId, + // set the error bit + functionCode: (0x80 | req.functionCode), + // set the exception code to illegal function to indicate that + // the server does not know how to handle this function code. + payload: []byte{exIllegalFunction}, + } + } + + // if there was no error processing the request but the response is nil + // (which should never happen), emit a server failure exception code + // and log an error + if err == nil && res == nil { + err = ErrServerDeviceFailure + ms.logger.Errorf("internal server error (req: %v, res: %v, err: %v)", + req, res, err) + } + + // map go errors to modbus errors, unless the error is a protocol error, + // in which case close the transport and return. + if err != nil { + if err == ErrProtocolError { + ms.logger.Warningf( + "protocol error, closing link (client address: '%s')", + clientAddr) + t.Close() + return + } else { + res = &pdu{ + unitId: req.unitId, + functionCode: (0x80 | req.functionCode), + payload: []byte{mapErrorToExceptionCode(err)}, + } + } + } + + // write the response to the transport + err = t.WriteResponse(res) + if err != nil { + ms.logger.Warningf("failed to write response: %v", err) + } + + // avoid holding on to stale data + req = nil + res = nil + } + + // never reached + return +} + +// startTLS performs a full TLS handshake (with client authentication) on tcpSock +// and returns a 'wrapped' clear-text socket suitable for use by the TCP transport. +func (ms *ModbusServer) startTLS(tcpSock net.Conn) ( + tlsSock *tls.Conn, clientRole string, err error) { + var connState tls.ConnectionState + + // set a 30s timeout for the TLS handshake to complete + err = tcpSock.SetDeadline(time.Now().Add(30 * time.Second)) + if err != nil { + return + } + + // start TLS negotiation over the raw TCP connection + tlsSock = tls.Server(tcpSock, &tls.Config{ + Certificates: []tls.Certificate{ + *ms.conf.TLSServerCert, + }, + ClientCAs: ms.conf.TLSClientCAs, + // require a valid (verified) certificate from the client + // (see R-06, R-08 and R-10 of the MBAPS spec) + ClientAuth: tls.RequireAndVerifyClientCert, + // mandate TLSv1.2 or higher (see R-01 of the MBAPS spec) + MinVersion: tls.VersionTLS12, + }) + + // complete the full TLS handshake (with client cert validation) + err = tlsSock.Handshake() + if err != nil { + return + } + + // look for and extract the client's role, if any + connState = tlsSock.ConnectionState() + if len(connState.PeerCertificates) == 0 { + err = errors.New("no client certificate received") + return + } + // From the tls.ConnectionState doc: + // "The first element is the leaf certificate that the connection is + // verified against." + clientRole = ms.extractRole(connState.PeerCertificates[0]) + + return +} + +// extractRole looks for Modbus Role extensions in a certificate and returns the +// role as a string. +// If no role extension is found, a nil string is returned (R-23). +// If multiple or invalid role extensions are found, a nil string is returned (R-65, R-22). +func (ms *ModbusServer) extractRole(cert *x509.Certificate) (role string) { + var err error + var found bool + var badCert bool + + // walk through all extensions looking for Modbus Role OIDs + for _, ext := range cert.Extensions { + if ext.Id.Equal(modbusRoleOID) { + + // there must be only one role extension per cert (R-65) + if found { + ms.logger.Warning("client certificate contains more than one role OIDs") + badCert = true + break + } + found = true + + // the role extension must use UTF8String encoding (R-22) + // (the ASN1 tag for UTF8String is 0x0c) + if len(ext.Value) < 2 || ext.Value[0] != 0x0c { + badCert = true + break + } + + // extract the ASN1 string + _, err = asn1.Unmarshal(ext.Value, &role) + if err != nil { + ms.logger.Warningf("failed to decode Modbus Role extension: %v", err) + badCert = true + break + } + } + } + + // blank the role if we found more than one Role extension + if badCert { + role = "" + } + + return +} diff --git a/server_tcp_test.go b/server_tcp_test.go new file mode 100644 index 0000000..7db2f79 --- /dev/null +++ b/server_tcp_test.go @@ -0,0 +1,658 @@ +package modbus + +import ( + "testing" + "time" +) + +func TestTCPServerWithConcurrentConnections(t *testing.T) { + var server *ModbusServer + var err error + var coils []bool + var c1 *ModbusClient + var c2 *ModbusClient + var c3 *ModbusClient + var th *tcpTestHandler + + th = &tcpTestHandler{} + + server, err = NewServer(&ServerConfiguration{ + URL: "tcp://localhost:5502", + MaxClients: 2, + }, th) + if err != nil { + t.Errorf("failed to create server: %v", err) + } + + err = server.Start() + if err != nil { + t.Errorf("failed to start server: %v", err) + } + + // create 3 modbus clients + c1, err = NewClient(&ClientConfiguration{ + URL: "tcp://localhost:5502", + }) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + c2, err = NewClient(&ClientConfiguration{ + URL: "tcp://localhost:5502", + }) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + c3, err = NewClient(&ClientConfiguration{ + URL: "tcp://localhost:5502", + }) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + + // the server should have zero client connections so far + server.lock.Lock() + if len(server.tcpClients) != 0 { + t.Errorf("expected server.tcpClients to hold 0 entries, got: %v", + len(server.tcpClients)) + } + server.lock.Unlock() + + // connect client #1 + err = c1.Open() + if err != nil { + t.Errorf("c1.Connect() should have succeeded, got: %v", err) + } + c1.SetUnitId(9) + + // the server should have 1 client connection at this point + time.Sleep(time.Millisecond) + server.lock.Lock() + if len(server.tcpClients) != 1 { + t.Errorf("expected server.tcpClients to hold 1 entry, got: %v", + len(server.tcpClients)) + } + server.lock.Unlock() + + // connect client #2 + err = c2.Open() + if err != nil { + t.Errorf("c2.Connect() should have succeeded, got: %v", err) + } + c2.SetUnitId(9) + + time.Sleep(time.Millisecond) + // the server should now have 2 client connections, its maximum allowed + server.lock.Lock() + if len(server.tcpClients) != 2 { + t.Errorf("expected server.tcpClients to hold 2 entries, got: %v", + len(server.tcpClients)) + } + server.lock.Unlock() + + // connect client #3 + err = c3.Open() + if err != nil { + t.Errorf("c3.Connect() should have succeeded, got: %v", err) + } + c3.SetUnitId(9) + + // since the previous client was rejected, the active connection count + // should stay at 2 + server.lock.Lock() + if len(server.tcpClients) != 2 { + t.Errorf("expected server.tcpClients to hold 2 entries, got: %v", + len(server.tcpClients)) + } + server.lock.Unlock() + + // c1 and c2 should both be able to make requests while c3 should error out + // as it has been disconnected (conn closed) + coils, err = c1.ReadCoils(0x0000, 2) + if err != nil { + t.Errorf("c1.ReadCoils() should have succeeded, got: %v", err) + } + if coils[0] == true || coils[1] == true { + t.Errorf("expected {false, false}, got: %v", coils) + } + + coils, err = c2.ReadCoils(0x0003, 5) + if err != nil { + t.Errorf("c2.ReadCoils() should have succeeded, got: %v", err) + } + if coils[0] != false || coils[1] != false { + t.Errorf("expected {false, false}, got: %v", coils) + } + + _, err = c3.ReadCoil(0x0001) + if err == nil { + t.Errorf("c3.ReadCoil() should have failed") + } + + // close c2 and make sure the connection is freed + c2.Close() + time.Sleep(time.Millisecond) + server.lock.Lock() + if len(server.tcpClients) != 1 { + t.Errorf("expected server.tcpClients to hold 1 entry, got: %v", + len(server.tcpClients)) + } + server.lock.Unlock() + + // reconnect c2 + err = c2.Open() + if err != nil { + t.Errorf("c2.Open should have succeeded, got: %v", err) + } + + // write to the coil at address #1 + err = c2.WriteCoil(0x0001, true) + if err != nil { + t.Errorf("c2.WriteCoil() should have succeeded, got: %v", err) + } + + server.lock.Lock() + if len(server.tcpClients) != 2 { + t.Errorf("expected server.tcpClients to hold 2 entries, got: %v", + len(server.tcpClients)) + } + server.lock.Unlock() + + // check the coil value with c1 + coils, err = c1.ReadCoils(0x0000, 2) + if err != nil { + t.Errorf("c1.ReadCoils() should have succeeded, got: %v", err) + } + if coils[0] != false || coils[1] != true { + t.Errorf("expected {false, true}, got: %v", coils) + } + + // close c1 and make sure the connection is freed + c1.Close() + time.Sleep(time.Millisecond) + server.lock.Lock() + if len(server.tcpClients) != 1 { + t.Errorf("expected server.tcpClients to hold 1 entry, got: %v", + len(server.tcpClients)) + } + server.lock.Unlock() + + // stopping the server should disconnect all clients + server.Stop() + + time.Sleep(time.Millisecond) + server.lock.Lock() + if len(server.tcpClients) != 0 { + t.Errorf("expected server.tcpClients to hold 0 entries, got: %v", + len(server.tcpClients)) + } + server.lock.Unlock() + + // c2 should have been disconnected + coils, err = c2.ReadCoils(0x0003, 5) + if err == nil { + t.Errorf("c2.ReadCoils() should have failed") + } + + return +} + +func TestTCPServerCoilsAndDiscreteInputs(t *testing.T) { + var server *ModbusServer + var err error + var coils []bool + var dis []bool + var client *ModbusClient + var th *tcpTestHandler + + th = &tcpTestHandler{} + + server, err = NewServer(&ServerConfiguration{ + URL: "tcp://localhost:5504", + MaxClients: 2, + }, th) + if err != nil { + t.Errorf("failed to create server: %v", err) + } + + err = server.Start() + if err != nil { + t.Errorf("failed to start server: %v", err) + } + + client, err = NewClient(&ClientConfiguration{ + URL: "tcp://localhost:5504", + }) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + + err = client.Open() + if err != nil { + t.Errorf("client.Open() should have succeeded, got: %v", err) + } + client.SetUnitId(9) + + // make sure both coils and discrete inputs are all false/0 + coils, err = client.ReadCoils(0x0000, 10) + if err != nil { + t.Errorf("client.ReadCoils() should have succeeded, got: %v", err) + } + for i := 0; i < 10; i++ { + if coils[i] != false { + t.Errorf("expected coil at addr 0x%04x to be false", i) + } + } + + dis, err = client.ReadDiscreteInputs(0x0000, 10) + if err != nil { + t.Errorf("client.ReadDiscreteInputs() should have succeeded, got: %v", err) + } + for i := 0; i < 10; i++ { + if dis[i] != false { + t.Errorf("expected discrete input at addr 0x%04x to be false", i) + } + } + + // set discrete inputs to random values + th.di = [10]bool{ + false, false, false, true, false, true, true, true, true, true, + } + + // read the discrete inputs again + dis, err = client.ReadDiscreteInputs(0x0000, 10) + if err != nil { + t.Errorf("client.ReadDiscreteInput() should have succeeded, got: %v", err) + } + for i, b := range [10]bool{ + false, false, false, true, false, true, true, true, true, true, + } { + if dis[i] != b { + t.Errorf("expected discrete input at addr 0x%04x to be %v", i, b) + } + } + + // reading past the array size should return ErrIllegalDataAddress + _, err = client.ReadDiscreteInputs(0x000a, 1) + if err != ErrIllegalDataAddress { + t.Errorf("expected ErrIllegalDataAddress, got: %v", err) + } + _, err = client.ReadCoils(0x000a, 1) + if err != ErrIllegalDataAddress { + t.Errorf("expected ErrIllegalDataAddress, got: %v", err) + } + _, err = client.ReadDiscreteInputs(0x8, 3) + if err != ErrIllegalDataAddress { + t.Errorf("expected ErrIllegalDataAddress, got: %v", err) + } + _, err = client.ReadCoils(0x8, 3) + if err != ErrIllegalDataAddress { + t.Errorf("expected ErrIllegalDataAddress, got: %v", err) + } + + // the coils shouldn't have changed + coils, err = client.ReadCoils(0x0000, 10) + if err != nil { + t.Errorf("client.ReadCoils() should have succeeded, got: %v", err) + } + for i := 0; i < 10; i++ { + if coils[i] != false { + t.Errorf("expected coil at addr 0x%04x to be false", i) + } + } + + // write to a single coil + err = client.WriteCoil(0x0004, true) + if err != nil { + t.Errorf("client.WriteCoil() should have succeeded, got: %v", err) + } + + // make sure it has been written to + coils, err = client.ReadCoils(0x0003, 3) + if err != nil { + t.Errorf("client.ReadCoils() should have succeeded, got: %v", err) + } + for i, v := range []bool{false, true, false,} { + if coils[i] != v { + t.Errorf("expected coil at addr 0x%04x to be %v", 3 + i, v) + } + } + + // write to multiple coils at once + err = client.WriteCoils(0x0005, []bool{ + true, false, true, true, + }) + if err != nil { + t.Errorf("client.WriteCoils() should have succeeded, got: %v", err) + } + + // make sure the write went through + coils, err = client.ReadCoils(0x0005, 4) + if err != nil { + t.Errorf("client.ReadCoils() should have succeeded, got: %v", err) + } + for i, v := range []bool{true, false, true, true,} { + if coils[i] != v { + t.Errorf("expected coil at addr 0x%04x to be %v", 3 + i, v) + } + } + + // switch to another unit ID and make sure both coil and discrete input operations + // return ErrIllegalFunction + client.SetUnitId(5) + err = client.WriteCoils(0x0005, []bool{ + true, false, true, true, + }) + if err != ErrIllegalFunction { + t.Errorf("client.WriteCoils() should have returned ErrIllegalFunction, got: %v", err) + } + err = client.WriteCoil(0x0005, false) + if err != ErrIllegalFunction { + t.Errorf("client.WriteCoil() should have returned ErrIllegalFunction, got: %v", err) + } + coils, err = client.ReadCoils(0x0005, 1) + if err != ErrIllegalFunction { + t.Errorf("client.ReadCoils() should have returned ErrIllegalFunction, got: %v", err) + } + coils, err = client.ReadDiscreteInputs(0x0005, 1) + if err != ErrIllegalFunction { + t.Errorf("client.ReadDiscreteInputs() should have returned ErrIllegalFunction, got: %v", err) + } + + client.Close() + server.Stop() + + return +} + +func TestTCPServerHoldingAndInputRegisters(t *testing.T) { + var server *ModbusServer + var err error + var client *ModbusClient + var th *tcpTestHandler + var regs []uint16 + + th = &tcpTestHandler{} + + server, err = NewServer(&ServerConfiguration{ + URL: "tcp://localhost:5504", + MaxClients: 2, + }, th) + if err != nil { + t.Errorf("failed to create server: %v", err) + } + + err = server.Start() + if err != nil { + t.Errorf("failed to start server: %v", err) + } + + client, err = NewClient(&ClientConfiguration{ + URL: "tcp://localhost:5504", + }) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + + err = client.Open() + if err != nil { + t.Errorf("client.Open() should have succeeded, got: %v", err) + } + client.SetUnitId(9) + + // all 10 input registers should be 0x0000 + regs, err = client.ReadRegisters(0x0000, 10, INPUT_REGISTER) + if err != nil { + t.Errorf("client.ReadRegisters() should have succeeded, got: %v", err) + } + for i := 0; i < 10; i++ { + if regs[i] != 0x0000 { + t.Errorf("expected 0x0000 at position %v, got: 0x%04x", i, regs[i]) + } + } + + // assign some values to the handler's input registers + for i := range th.input { + th.input[i] = 0xa710 + uint16(i) + } + + regs, err = client.ReadRegisters(0x0000, 10, INPUT_REGISTER) + if err != nil { + t.Errorf("client.ReadRegisters() should have succeeded, got: %v", err) + } + for i := 0; i < 10; i++ { + if regs[i] != 0xa710 + uint16(i) { + t.Errorf("expected 0x%04x at position %v, got: 0x%04x", + 0xa710 + uint16(i), i, regs[i]) + } + } + + // reading addr 0x0009 (the very last register) should succeed + regs, err = client.ReadRegisters(0x0009, 1, INPUT_REGISTER) + if err != nil { + t.Errorf("client.ReadRegisters() should have succeeded, got: %v", err) + } + if regs[0] != 0xa719 { + t.Errorf("expected 0xa719 at address 9, saw: 0x%04x", regs[0]) + } + + // reading past address 0x000a should fail + regs, err = client.ReadRegisters(0x0001, 10, INPUT_REGISTER) + if err != ErrIllegalDataAddress { + t.Errorf("client.ReadRegisters() should have returned ErrIllegalDataAddress, got: %v", err) + } + regs, err = client.ReadRegisters(0x0000, 11, INPUT_REGISTER) + if err != ErrIllegalDataAddress { + t.Errorf("client.ReadRegisters() should have returned ErrIllegalDataAddress, got: %v", err) + } + + // all 10 holding registers should still be 0x0000 + regs, err = client.ReadRegisters(0x0000, 10, HOLDING_REGISTER) + if err != nil { + t.Errorf("client.ReadRegisters() should have succeeded, got: %v", err) + } + for i := 0; i < 10; i++ { + if regs[i] != 0x0000 { + t.Errorf("expected 0x0000 at position %v, got: 0x%04x", i, regs[i]) + } + } + + // write to a single valid register (with opcode 0x06) + err = client.WriteRegister(0x0007, 0xfea1) + if err != nil { + t.Errorf("client.WriteRegister() should have succeeded, got: %v", err) + } + + // make sure it has been written to + regs, err = client.ReadRegisters(0x0005, 5, HOLDING_REGISTER) + if err != nil { + t.Errorf("client.ReadRegisters() should have succeeded, got: %v", err) + } + for i := 0; i < 5; i++ { + if i != 2 && regs[i] != 0x0000 { + t.Errorf("expected 0x0000 at position %v, got: 0x%04x", i, regs[i]) + } + if i == 2 && regs[i] != 0xfea1 { + t.Errorf("expected 0xfea1 at position %v, got: 0x%04x", i, regs[i]) + } + } + + // check values in the handler as well + for i := 0; i < 10; i++ { + if i != 7 && th.holding[i] != 0x0000 { + t.Errorf("expected 0x0000 at handler index %v, got: 0x%04x", i, regs[i]) + } + if i == 7 && th.holding[i] != 0xfea1 { + t.Errorf("expected 0xfea1 at handler index %v, got: 0x%04x", i, regs[i]) + } + } + + // write multiple registers at once (with function code 0x10) + err = client.WriteRegisters(0x0001, []uint16{ + 0x0c11, 0x0c22, 0x0c33, 0x0c44, + 0x0c55, 0x0c66, 0x0c77, 0x0c88, + 0x0c99, + }) + if err != nil { + t.Errorf("client.WriteRegisters() should have succeeded, got: %v", err) + } + + // write to a single valid register (with opcode 0x06) + err = client.WriteRegister(0x0000, 0x0c00) + if err != nil { + t.Errorf("client.WriteRegister() should have succeeded, got: %v", err) + } + + // make sure they have all been written to + regs, err = client.ReadRegisters(0x0000, 10, HOLDING_REGISTER) + if err != nil { + t.Errorf("client.ReadRegisters() should have succeeded, got: %v", err) + } + for i := 0; i < 10; i++ { + if regs[i] != 0x0c00 + uint16(0x11 * i) { + t.Errorf("expected ox%04x at position %v, got: 0x%04x", + 0x0c00 + uint16(0x11 * i), i, regs[i]) + } + } + + // check values in the handler as well + for i := 0; i < 10; i++ { + if th.holding[i] != 0x0c00 + uint16(0x11 * i) { + t.Errorf("expected 0xfea1 at handler index %v, got: 0x%04x", i, regs[i]) + } + } + + // reading addr 0x0009 (the very last register) should succeed + regs, err = client.ReadRegisters(0x0009, 1, HOLDING_REGISTER) + if err != nil { + t.Errorf("client.ReadRegisters() should have succeeded, got: %v", err) + } + if regs[0] != 0x0c99 { + t.Errorf("expected 0x0c99 at address 9, saw: 0x%04x", regs[0]) + } + + // reading past address 0x000a should fail + regs, err = client.ReadRegisters(0x0001, 10, HOLDING_REGISTER) + if err != ErrIllegalDataAddress { + t.Errorf("client.ReadRegisters() should have returned ErrIllegalDataAddress, got: %v", err) + } + regs, err = client.ReadRegisters(0x0000, 11, HOLDING_REGISTER) + if err != ErrIllegalDataAddress { + t.Errorf("client.ReadRegisters() should have returned ErrIllegalDataAddress, got: %v", err) + } + + // switch to another unit ID and make sure both holding and input register operations + // return ErrIllegalFunction + client.SetUnitId(2) + err = client.WriteRegisters(0x0005, []uint16{ + 0x0000, 0x0001, + }) + if err != ErrIllegalFunction { + t.Errorf("client.WriteRegisters() should have returned ErrIllegalFunction, got: %v", err) + } + err = client.WriteRegister(0x0001, 0xffff) + if err != ErrIllegalFunction { + t.Errorf("client.WriteRegister() should have returned ErrIllegalFunction, got: %v", err) + } + regs, err = client.ReadRegisters(0x0005, 1, HOLDING_REGISTER) + if err != ErrIllegalFunction { + t.Errorf("client.ReadRegisters() should have returned ErrIllegalFunction, got: %v", err) + } + regs, err = client.ReadRegisters(0x0005, 1, INPUT_REGISTER) + if err != ErrIllegalFunction { + t.Errorf("client.ReadRegisters() should have returned ErrIllegalFunction, got: %v", err) + } + + client.Close() + server.Stop() + + return +} + +type tcpTestHandler struct { + coils [10]bool + di [10]bool + input [10]uint16 + holding [10]uint16 +} + +func (th *tcpTestHandler) HandleCoils(req *CoilsRequest) (res []bool, err error) { + if req.UnitId != 9 { + // only reply to unit ID #9 + err = ErrIllegalFunction + return + } + + if req.Addr + req.Quantity > uint16(len(th.coils)) { + err = ErrIllegalDataAddress + return + } + + for i := 0; i < int(req.Quantity); i++ { + if req.IsWrite { + th.coils[int(req.Addr) + i] = req.Args[i] + } + res = append(res, th.coils[int(req.Addr) + i]) + } + + return +} + +func (th *tcpTestHandler) HandleDiscreteInputs(req *DiscreteInputsRequest) (res []bool, err error) { + if req.UnitId != 9 { + // only reply to unit ID #9 + err = ErrIllegalFunction + return + } + + if req.Addr + req.Quantity > uint16(len(th.di)) { + err = ErrIllegalDataAddress + return + } + + for i := 0; i < int(req.Quantity); i++ { + res = append(res, th.di[int(req.Addr) + i]) + } + + return +} + +func (th *tcpTestHandler) HandleHoldingRegisters(req *HoldingRegistersRequest) (res []uint16, err error) { + if req.UnitId != 9 { + // only reply to unit ID #9 + err = ErrIllegalFunction + return + } + + if req.Addr + req.Quantity > uint16(len(th.holding)) { + err = ErrIllegalDataAddress + return + } + + for i := 0; i < int(req.Quantity); i++ { + if req.IsWrite { + th.holding[int(req.Addr) + i] = req.Args[i] + } + res = append(res, th.holding[int(req.Addr) + i]) + } + + return +} + +func (th *tcpTestHandler) HandleInputRegisters(req *InputRegistersRequest) (res []uint16, err error) { + if req.UnitId != 9 { + // only reply to unit ID #9 + err = ErrIllegalFunction + return + } + + if req.Addr + req.Quantity > uint16(len(th.input)) { + err = ErrIllegalDataAddress + return + } + + for i := 0; i < int(req.Quantity); i++ { + res = append(res, th.input[int(req.Addr) + i]) + } + + return +} diff --git a/server_tls_test.go b/server_tls_test.go new file mode 100644 index 0000000..9a6037a --- /dev/null +++ b/server_tls_test.go @@ -0,0 +1,572 @@ +package modbus + +import ( + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "testing" + "time" +) + +const ( + clientCertWithRoleOID string = ` +-----BEGIN CERTIFICATE----- +MIIGCDCCA/CgAwIBAgIUdNWUjckypyaWon4eQm8dKWHQPBEwDQYJKoZIhvcNAQEL +BQAwJjEkMCIGA1UEAwwbVEVTVCBDTElFTlQgQ0VSVCBETyBOT1QgVVNFMB4XDTIw +MDgyODE4MDIyMVoXDTQwMDgyMzE4MDIyMVowJjEkMCIGA1UEAwwbVEVTVCBDTElF +TlQgQ0VSVCBETyBOT1QgVVNFMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKC +AgEAr9UnAZT8WDYOuI+0cxFAUnOw422osdhlvb7gGEZwwHMOe4k+D0PfQVFD0ctd +ZMBVL4O/YWOuKkpUlNBYFquu/eOuFVVdPs81y1u8EZ4kpYdeTiAgE5abANlMvnSH +eSIyFAeU0qS5UNKrYiOwJzKgNZ7SLbjZxFvdirjhSX7Y95bZ9O5K4x1MsB7dUYRz +weH5jHyOgqgj2Gccxkohg1npscDzFvyy73nJWhHCFXj7zhfLpJKHhu/9v7jEZkuT +Nl03XrsWjEWRy3YoW2xG8elvdD6LQAj2trh9bcq9h3UJdbtduLyLpcHIwNJtuCOx +Gek7kyGLhh67FeINXKrdEpwQuSdJw8DVARP3D+ltjpfGZeZN2urDvrijz+5i5DIx +O8QlqoEm5LWf232dKEPZcqw8Uz4SxRYgc8qcw9HDWaKHDkpddAL/D+EYt/LHMvTt +jJJ7IrgX20eo/QLnWwxcWOfc2YrrGAXnghKw2O3DqrOT5t5dK/hz/OQwPMGjN1pj +2OcYwdLvykqIS387DXeIzaiaxSIIwo6NV8uWxcQIr65Ajt8nTygHifmp3FRicrgO +Pycoww3j73Y61nYVSQ9Tpjg3I6OHQB7gW+ymb9QwOJ6/vs/DzDF1Meaw6xKKbF8n +A/JUxF0NVfdB+DafVP/MageokvpzMtRKH5Qp/GOJGpF/DXsCAwEAAaOCASwwggEo +MB0GA1UdDgQWBBSMyqL/JXXHSvl4tm6jetNvViTfzzAfBgNVHSMEGDAWgBSMyqL/ +JXXHSvl4tm6jetNvViTfzzAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBSMyqL/ +JXXHSvl4tm6jetNvViTfzzBhBgNVHSMEWjBYgBSMyqL/JXXHSvl4tm6jetNvViTf +z6EqpCgwJjEkMCIGA1UEAwwbVEVTVCBDTElFTlQgQ0VSVCBETyBOT1QgVVNFghR0 +1ZSNyTKnJpaifh5Cbx0pYdA8ETALBgNVHQ8EBAMCAqQwFgYDVR0lAQH/BAwwCgYI +KwYBBQUHAwIwEgYDVR0TAQH/BAgwBgEB/wIBADAaBgsrBgEEAYOJDIYiAQQLDAlv +cGVyYXRvcjIwDQYJKoZIhvcNAQELBQADggIBAF1czPdpHadmotgQTvtf/xoIr23Q +UqiyzUtpIwo+p/uZKRR9w0dVOpamoehbLuN4r8lb0EBKG/UbXaUpQozKBxUaIUOL +ZRKwvWCTaJFVLp4qqW7R8sxDDRovmndnBD98CkMOD7rWbHByfoVsgOYJ2QZLED84 +RaZDuRysnw4Z6spoE4krL3Aabp4z4t7CGPhZIVyLGBwjqXPFhS7BMLWEztVBEuxc +CKR9iz4+93flid1dTB3/NRYmEFpGfLShRkOIslUZtdnmSkdZ+vIhJeK14QP0o1Hf +gZmRpPHsEGAQTg5lbRqbz3n8hd5SeVX1SnL4orHqE2Xk/8zCb+uLl3nc78pxkDYH +t758FGkcCy2QvAxVqd3++ek4wH9VMBpD+Ds536eyagygWNaQwAqb2/LWwkodFCUj +VFkAQj1nLT9YmzDvG2VRNH58uuFdSwv6GwFda0tqs1PzGbdN7G6VtUMobu/v71kd +kIrWrPzOzNCR0Pn2JZqervWP0956W3Am2PJqG5o41qIjSrb8vzxpnlVHVjrhoKx9 +8GCaA/6WsQrH09Rai7wDKiRD/zyUEWfTAUMpNPYFPl092Khb9azzp5aj4OHU0Z2E +Fd5StjPuFnSwAIqv3IdthbHPz+ifOyRLxEYOaXImNJFWRyLdcrn7yPZ+X6+IjBJe +hG79y2z0UfKJstN+ +-----END CERTIFICATE----- +` + + clientKeyWithRoleOID string = ` +-----BEGIN PRIVATE KEY----- +MIIJQgIBADANBgkqhkiG9w0BAQEFAASCCSwwggkoAgEAAoICAQCv1ScBlPxYNg64 +j7RzEUBSc7Djbaix2GW9vuAYRnDAcw57iT4PQ99BUUPRy11kwFUvg79hY64qSlSU +0FgWq679464VVV0+zzXLW7wRniSlh15OICATlpsA2Uy+dId5IjIUB5TSpLlQ0qti +I7AnMqA1ntItuNnEW92KuOFJftj3ltn07krjHUywHt1RhHPB4fmMfI6CqCPYZxzG +SiGDWemxwPMW/LLveclaEcIVePvOF8ukkoeG7/2/uMRmS5M2XTdeuxaMRZHLdihb +bEbx6W90PotACPa2uH1tyr2HdQl1u124vIulwcjA0m24I7EZ6TuTIYuGHrsV4g1c +qt0SnBC5J0nDwNUBE/cP6W2Ol8Zl5k3a6sO+uKPP7mLkMjE7xCWqgSbktZ/bfZ0o +Q9lyrDxTPhLFFiBzypzD0cNZoocOSl10Av8P4Ri38scy9O2MknsiuBfbR6j9Audb +DFxY59zZiusYBeeCErDY7cOqs5Pm3l0r+HP85DA8waM3WmPY5xjB0u/KSohLfzsN +d4jNqJrFIgjCjo1Xy5bFxAivrkCO3ydPKAeJ+ancVGJyuA4/JyjDDePvdjrWdhVJ +D1OmODcjo4dAHuBb7KZv1DA4nr++z8PMMXUx5rDrEopsXycD8lTEXQ1V90H4Np9U +/8xqB6iS+nMy1EoflCn8Y4kakX8NewIDAQABAoICABMzrufQUmJ7vM3Q+77ZMnIO +qlGb5yFM5Yd8MdLU1nld10YMbdeS7O2gJ0zg7ZkUG/ltZNgI37tElMoPmp8XLqwR +UjCIOv+h91j28qnl4FCnYNgdUANzngfQsz3VUfobjuZ7EXiTfp1h9E9qYFFXiQFy +D7foiPeVpLMCj6/MB3u6YKEL6Oe2imptZHQDh/SzbeI2tAV2wTtfv1e0PsauagP8 +c0+eVxgp76BDcjOQG8ec96NIUT6eNNLcJa6aMEBum55fxg2Zh1t10uBxCapfeMl0 +Dxb2I6M+sIvt6RbC5D6UMJ79EC8Q45CTKmJCm5Od0eC2eBs0fe/c2OK20h+3JWg0 +eybKuX1GXTSkd5padBKZPJumINFIDUlaiQrRCFr4ZpcUJqBcrCkkRq2wyvuwAEiZ +IvNqcJqjqzZFhjhZKrxVIv3C9av6v0JQLdrbquYZbRji6KaSCcU2xCdbOmSvcr9w +909Lz9gwdCFYVWWRCLmfA8FSR+hDGLW8fT2CfbijYYBpGR0W9zLXCR7DD7Hx0bYj +ZJg+ri3Yw8yZ1mTT6ZLmlz3HEtSNJD1kD/QcpCanwvkGYXRO5FOVaMXGx5QVWs33 +fS9ChLesujKQ9ye5jFc5tw1rmClcTNtWipllctdPZGzfaEs6a+LfTbpB1/zd2eBK +XkeMAYWp+ES8XXYFUToRAoIBAQDjj3MdNR5sM9m87E/IS3+1GulkfQ/M39BFmCtc +O7gvWiK7asrP2TFU9vjU9zEomo4ABCYKNaJtMqtGxH9EapsTepdjb1UrjYpmKyuJ +SyTk8+iWRLQa3RnyWkE2MXqIKQQH37uB5/Id009is5f1X89mfGOXQuldDxQy3ygC +OeYOAVOl5fZH0NqikTDnpaViKyAqhp7bbRG0KBvqtYBj/k930auL3Ls6eONMlFfB +9IzYeM3lcbiYmfMhOPuYReFgpC62SQDWBALQagcS2Uh4vOP2fSd6ragi8telBeLW +fWl5TSy4tNurbgtYh8WdizBDO+DD0is+b8HSiWUPJfPH7QDDAoIBAQDFzrq8JvMm +cJWnSqsCnWD6CRj8w0CvXe9IjwK7cquUc/xsyZFbHbKe1kKf0+UxywbT+1uSaJ19 +ZbvyLAi16+S27aHI5SX3R2SnZQPA2GOWqimSHu5HAMfTOsqqbkLImRUnzl2Vl3lW ++AN/4FMvAlA1HotI3EiuQxLWstG5RNeo58sVMobaiZd7+xnBsov7MKgKC90eBmTR +uxbQmPJUFLhefpTly1E72rbYwZ2a2AOBq/WJ74+Gb7A6DoQQmSFqRHe5X1e5v8L4 +nUwiUd60J9ACOMKYCGzPkwXSPvfqcmuSL1KKupsJAcVV+AcC4qmNPfDWI6A87aha +b4a+78g4u3TpAoIBACTZ1SV0tbGGEAu1JRJlj4/PhN4+FnHyCLNMejEchq48ZYV+ +PMu9+2wr9o3eXfqaVMaR5Wsf1mbinrP+HDIDJYvY/W0f2WYNLM1wzkMUhSwCh7bV +92imR45kqUzSZGpqYfm4dJAL9Lx5vNBaDxCwbFDHcgVL06i7SWUXmE4L/EJmWppy +DBkDLHTJGGda/tZP74yTcmRMXGKVYDf5HoqS42Ge9a3XmAZXD1AWccO6C5j+rzEp +4l/sBmBp7uxw3Jee3uWsGtONoLsJgI2/3CmZRT1kdSE7wA+wzdUuh9Z+RrdbFRPw +TeaMEpBKpGjn4m/w4Ww0u8YHqRakI1Z5qenFaqsCggEAH3WUf04Wh7uKIZQfhIfx +H3MI9VI8XGetIbYU8ij3nuGfeNHJ+1rKyLY83Fx/7B5lFJu6YZufyIzAinB0ZjKB +KpK6k0/WbPB+0pyfLzF7DUA84k9nCAXYwgBssRReLLckBTOt8JeppapGLDVKJYTR +qtETx9+483YZben8ruGDBwruYo2pouIVJJO38fVqi+WeJBLk9NyBdlWx+DUK/VJa +TDUHi1B9t+49/FU2sqS+UgY+Q9TE19W1ilY6rMUd6l+/Rs0iD5mu8YlazW6F49Md +Iu1SDYnxfEXevCRlm3TdJN+/2e55r8IHV3fd7ZiM7Li4L+Z0mpwVlWR9YqqSBmvR +2QKCAQEAv1P9zlYiOjK5MlpP8rfWyb2CuUCT3DG9k7+RZMPL6QCp5Fc/xINsttJc +bPSwhuWjYYE2DpenZAcn4Mf8JhhdUf+yijLVZYSDINgfUMgrmSTETRB4X28KYrGJ +UG3kz2IQnbIfPPrekFcL87h6dc88lfq5U/inPqSoQYdE99XD8iTY3Tb2ESD8B7Zk +Xh9uF519h8lnUA+/O6r3aLJ/d0ApKoLWancvenrkwe3jgc1MGUG0kjNLNCN310YW +lKNiMZCOhCMEGxo7pm1KBpPxxb+8Mo2ydxC2s4jhX748aMe1MvlTg5+IYUkqVDBq +isPLG4c6aPGxSbHirNfl6tBSngDy+A== +-----END PRIVATE KEY----- +` +) + +// TestTLSServer tests the TLS layer of the modbus server. +func TestTLSServer(t *testing.T) { + var err error + var server *ModbusServer + var serverKeyPair tls.Certificate + var client1KeyPair tls.Certificate + var client2KeyPair tls.Certificate + var clientCp *x509.CertPool + var serverCp *x509.CertPool + var th *tlsTestHandler + var c1 *ModbusClient + var c2 *ModbusClient + var regs []uint16 + var coils []bool + + th = &tlsTestHandler{} + + // load server keypair (from client_tls_test.go) + serverKeyPair, err = tls.X509KeyPair([]byte(serverCert), []byte(serverKey)) + if err != nil { + t.Errorf("failed to load test server key pair: %v", err) + return + } + + // load the first client keypair (from client_tls_test.go) + // this client cert doesn't have any Modbus Role extension + client1KeyPair, err = tls.X509KeyPair([]byte(clientCert), []byte(clientKey)) + if err != nil { + t.Errorf("failed to load test client key pair: %v", err) + return + } + + // load the second client keypair (defined above) + // this client cert has an "operator2" Modbus Role extension + client2KeyPair, err = tls.X509KeyPair( + []byte(clientCertWithRoleOID), []byte(clientKeyWithRoleOID)) + if err != nil { + t.Errorf("failed to load test client key pair: %v", err) + return + } + + // load the server cert into the client CA cert pool to get the server cert + // accepted by clients + clientCp = x509.NewCertPool() + if !clientCp.AppendCertsFromPEM([]byte(serverCert)) { + t.Errorf("failed to load test server cert into cert pool") + } + + // start with an empty server cert pool initially to reject the client + // certificate + serverCp = x509.NewCertPool() + + server, err = NewServer(&ServerConfiguration{ + URL: "tcp+tls://localhost:5802", + MaxClients: 2, + TLSServerCert: &serverKeyPair, + TLSClientCAs: serverCp, + }, th) + if err != nil { + t.Errorf("failed to create server: %v", err) + } + + err = server.Start() + if err != nil { + t.Errorf("failed to start server: %v", err) + } + + // create 2 modbus clients + c1, err = NewClient(&ClientConfiguration{ + URL: "tcp+tls://localhost:5802", + TLSClientCert: &client1KeyPair, + TLSRootCAs: clientCp, + }) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + c2, err = NewClient(&ClientConfiguration{ + URL: "tcp+tls://localhost:5802", + TLSClientCert: &client2KeyPair, + TLSRootCAs: clientCp, + }) + if err != nil { + t.Errorf("failed to create client: %v", err) + } + + // attempt to connect and use the first client. since its cert + // is not trusted by the server, a TLS error should occur on the first + // request. + err = c1.Open() + if err != nil { + t.Errorf("c1.Open() should have succeeded") + } + coils, err = c1.ReadCoils(0, 5) + if err == nil { + t.Error("c1.ReadCoils() should have failed") + } + c1.Close() + + // now place both client certs in the server's authorized client list + // to get them past the TLS client cert validation procedure + if !serverCp.AppendCertsFromPEM([]byte(clientCert)) { + t.Errorf("failed to load client#1 cert into cert pool") + } + if !serverCp.AppendCertsFromPEM([]byte(clientCertWithRoleOID)) { + t.Errorf("failed to load client#2 cert into cert pool") + } + + // connect both clients: should succeed + err = c1.Open() + if err != nil { + t.Error("c1.Open() should have succeeded") + } + + err = c2.Open() + if err != nil { + t.Error("c2.Open() should have succeeded") + } + + // client #2 (with 'operator2' role) should have read/write access to coils while + // client #1 (without role) should only be able to read. + err = c1.WriteCoil(0, true) + if err != ErrIllegalFunction { + t.Errorf("c1.WriteCoil() should have failed with %v, got: %v", + ErrIllegalFunction, err) + } + + coils, err = c1.ReadCoils(0, 5) + if err != nil { + t.Errorf("c1.ReadCoils() should have succeeded, got: %v", err) + } + if coils[0] { + t.Errorf("coils[0] should have been false") + } + + err = c2.WriteCoil(0, true) + if err != nil { + t.Errorf("c2.WriteCoil() should have succeeded, got: %v", err) + } + + coils, err = c2.ReadCoils(0, 5) + if err != nil { + t.Errorf("c2.ReadCoils() should have succeeded, got: %v", err) + } + if !coils[0] { + t.Errorf("coils[0] should have been true") + } + + coils, err = c1.ReadCoils(0, 5) + if err != nil { + t.Errorf("c1.ReadCoils() should have succeeded, got: %v", err) + } + if !coils[0] { + t.Errorf("coils[0] should have been true") + } + + // client #1 should only be allowed access to holding registers of unit id #1 + // while client#2 should be allowed access to holding registers of unit ids #1 and #4 + c1.SetUnitId(1) + err = c1.WriteRegister(2, 100) + if err != nil { + t.Errorf("c1.WriteRegister() should have succeeded, got: %v", err) + } + + c1.SetUnitId(4) + err = c1.WriteRegister(2, 200) + if err != ErrIllegalFunction { + t.Errorf("c1.WriteRegister() should have failed with %v, got: %v", + ErrIllegalFunction, err) + } + + c2.SetUnitId(1) + regs, err = c2.ReadRegisters(1, 2, HOLDING_REGISTER) + if err != nil { + t.Errorf("c2.ReadRegisters() should have succeeded, got: %v", err) + } + if regs[0] != 0 || regs[1] != 100 { + t.Errorf("unexpected register values: %v", regs) + } + + c2.SetUnitId(4) + err = c2.WriteRegister(2, 200) + if err != nil { + t.Errorf("c2.WriteRegister() should have succeeded, got: %v", err) + } + + regs, err = c2.ReadRegisters(1, 2, HOLDING_REGISTER) + if err != nil { + t.Errorf("c2.ReadRegisters() should have succeeded, got: %v", err) + } + if regs[0] != 0 || regs[1] != 200 { + t.Errorf("unexpected register values: %v", regs) + } + + // close the server and all client connections + server.Stop() + + // make sure all underlying TCP client connections have been freed + time.Sleep(10 * time.Millisecond) + server.lock.Lock() + if len(server.tcpClients) != 0 { + t.Errorf("expected 0 client connections, saw: %v", len(server.tcpClients)) + } + server.lock.Unlock() + + // cleanup + c1.Close() + c2.Close() + + return +} + +type tlsTestHandler struct { + coils [10]bool + holdingId1 [10]uint16 + holdingId4 [10]uint16 +} + +func (th *tlsTestHandler) HandleCoils(req *CoilsRequest) (res []bool, err error) { + // coils access is allowed to any client with a valid cert, but + // the "operator2" role is required to write + if req.IsWrite && req.ClientRole != "operator2" { + err = ErrIllegalFunction + return + } + + if req.Addr + req.Quantity > uint16(len(th.coils)) { + err = ErrIllegalDataAddress + return + } + + for i := 0; i < int(req.Quantity); i++ { + if req.IsWrite { + th.coils[int(req.Addr) + i] = req.Args[i] + } + res = append(res, th.coils[int(req.Addr) + i]) + } + + return +} + +func (th *tlsTestHandler) HandleDiscreteInputs(req *DiscreteInputsRequest) (res []bool, err error) { + // there are no digital inputs on this device + err = ErrIllegalDataAddress + + return +} + +func (th *tlsTestHandler) HandleHoldingRegisters(req *HoldingRegistersRequest) (res []uint16, err error) { + // gate unit id #4 behind the "operator2" role while access to unit id #1 + // is allowed to any valid cert + if req.UnitId == 0x04 { + if req.ClientRole != "operator2" { + err = ErrIllegalFunction + return + } + + if req.Addr + req.Quantity > uint16(len(th.holdingId4)) { + err = ErrIllegalDataAddress + return + } + + for i := 0; i < int(req.Quantity); i++ { + if req.IsWrite { + th.holdingId4[int(req.Addr) + i] = req.Args[i] + } + res = append(res, th.holdingId4[int(req.Addr) + i]) + } + } else if req.UnitId == 0x01 { + if req.Addr + req.Quantity > uint16(len(th.holdingId1)) { + err = ErrIllegalDataAddress + return + } + + for i := 0; i < int(req.Quantity); i++ { + if req.IsWrite { + th.holdingId1[int(req.Addr) + i] = req.Args[i] + } + res = append(res, th.holdingId1[int(req.Addr) + i]) + } + } else { + err = ErrIllegalFunction + return + } + + return +} + +func (th *tlsTestHandler) HandleInputRegisters(req *InputRegistersRequest) (res []uint16, err error) { + // there are no inputs registers on this device + err = ErrIllegalDataAddress + + return +} + +func TestServerExtractRole(t *testing.T) { + var ms *ModbusServer + var pemBlock *pem.Block + var x509Cert *x509.Certificate + var err error + var role string + + ms = &ModbusServer{ + logger: newLogger("test-server-role-extraction", nil), + } + + // load a client cert without role OID + pemBlock, _ = pem.Decode([]byte(clientCert)) + if err != nil { + t.Errorf("failed to decode client cert: %v", err) + return + } + + x509Cert, err = x509.ParseCertificate(pemBlock.Bytes) + if err != nil { + t.Errorf("failed to parse client cert: %v", err) + return + } + + // calling extractRole on a cert without role extension should return an + // empty string (see R-23 of the MBAPS spec) + role = ms.extractRole(x509Cert) + if role != "" { + t.Errorf("role should have been empty, got: '%s'", role) + } + + // load a certificate with a single role extension of "operator2" + pemBlock, _ = pem.Decode([]byte(clientCertWithRoleOID)) + if err != nil { + t.Errorf("failed to decode client cert: %v", err) + return + } + + x509Cert, err = x509.ParseCertificate(pemBlock.Bytes) + if err != nil { + t.Errorf("failed to parse client cert: %v", err) + return + } + + role = ms.extractRole(x509Cert) + if role != "operator2" { + t.Errorf("role should have been 'operator2', got: '%s'", role) + } + + // build a certificate with multiple Modbus Role extensions: they should + // all be rejected + x509Cert = &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: modbusRoleOID, + Value: []byte{ + 0x0c, 0x04, 0x66, 0x77, 0x67, 0x78, + // ^ ASN1:UTF8String + // ^ length + // ^ 4-byte string 'fwgx' + }, + }, + { + Id: modbusRoleOID, + Value: []byte{ + 0x0c, 0x02, 0x66, 0x67, + // ^ ASN1:UTF8String + // ^ length + // ^ 2-byte string 'fwwf' + }, + }, + }, + } + + role = ms.extractRole(x509Cert) + if role != "" { + t.Errorf("role should have been empty, got: '%s'", role) + } + + // build a certificate with a single Modbus Role extension of the wrong + // type: the role should be rejected + x509Cert = &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: modbusRoleOID, + Value: []byte{ + 0x13, 0x04, 0x66, 0x77, 0x67, 0x78, + // ^ ASN1:PrintableString + // ^ length + // ^ 4-byte string 'fwgx' + }, + }, + }, + } + + role = ms.extractRole(x509Cert) + if role != "" { + t.Errorf("role should have been empty, got: '%s'", role) + } + + // build a certificate with a single, short Modbus Role extension: the role + // should be rejected + x509Cert = &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: modbusRoleOID, + Value: []byte{ + 0x0c, + // ^ ASN1:UTF8String + // ^ missing length + payload bytes + }, + }, + }, + } + + role = ms.extractRole(x509Cert) + if role != "" { + t.Errorf("role should have been empty, got: '%s'", role) + } + + // build a certificate with one bad Modbus Role extension (short) and one + // valid: they should both be rejected + x509Cert = &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: modbusRoleOID, + Value: []byte{ + 0x0c, + // ^ ASN1:UTF8String + // ^ missing length + payload bytes + }, + }, + { + Id: modbusRoleOID, + Value: []byte{ + 0x0c, 0x02, 0x66, 0x67, + // ^ ASN1:UTF8String + // ^ length + // ^ 2-byte string 'fwwf' + }, + }, + }, + } + + role = ms.extractRole(x509Cert) + if role != "" { + t.Errorf("role should have been empty, got: '%s'", role) + } + + // build a certificate with a single, valid Modbus Role extension: it should be + // accepted + x509Cert = &x509.Certificate{ + Extensions: []pkix.Extension{ + { + Id: modbusRoleOID, + Value: []byte{ + 0x0c, 0x04, 0x66, 0x77, 0x67, 0x78, + // ^ ASN1:UTF8String + // ^ length + // ^ 4-byte string 'fwgx' + }, + }, + }, + } + + role = ms.extractRole(x509Cert) + if role != "fwgx" { + t.Errorf("role should have been 'fwgx', got: '%s'", role) + } + + return +} diff --git a/tcp_transport.go b/tcp_transport.go new file mode 100644 index 0000000..b117bb1 --- /dev/null +++ b/tcp_transport.go @@ -0,0 +1,205 @@ +package modbus + +import ( + "fmt" + "io" + "log" + "net" + "time" +) + +const ( + maxTCPFrameLength int = 260 + mbapHeaderLength int = 7 +) + +type tcpTransport struct { + logger *logger + socket net.Conn + timeout time.Duration + lastTxnId uint16 +} + +// Returns a new TCP transport. +func newTCPTransport(socket net.Conn, timeout time.Duration, customLogger *log.Logger) (tt *tcpTransport) { + tt = &tcpTransport{ + socket: socket, + timeout: timeout, + logger: newLogger(fmt.Sprintf("tcp-transport(%s)", socket.RemoteAddr()), customLogger), + } + + return +} + +// Closes the underlying tcp socket. +func (tt *tcpTransport) Close() (err error) { + err = tt.socket.Close() + + return +} + +// Runs a request across the socket and returns a response. +func (tt *tcpTransport) ExecuteRequest(req *pdu) (res *pdu, err error) { + // set an i/o deadline on the socket (read and write) + err = tt.socket.SetDeadline(time.Now().Add(tt.timeout)) + if err != nil { + return + } + + // increase the transaction ID counter + tt.lastTxnId++ + + _, err = tt.socket.Write(tt.assembleMBAPFrame(tt.lastTxnId, req)) + if err != nil { + return + } + + res, err = tt.readResponse() + + return +} + +// Reads a request from the socket. +func (tt *tcpTransport) ReadRequest() (req *pdu, err error) { + var txnId uint16 + + // set an i/o deadline on the socket (read and write) + err = tt.socket.SetDeadline(time.Now().Add(tt.timeout)) + if err != nil { + return + } + + req, txnId, err = tt.readMBAPFrame() + if err != nil { + return + } + + // store the incoming transaction id + tt.lastTxnId = txnId + + return +} + +// Writes a response to the socket. +func (tt *tcpTransport) WriteResponse(res *pdu) (err error) { + _, err = tt.socket.Write(tt.assembleMBAPFrame(tt.lastTxnId, res)) + if err != nil { + return + } + + return +} + +// Reads as many MBAP+modbus frames as necessary until either the response +// matching tt.lastTxnId is received or an error occurs. +func (tt *tcpTransport) readResponse() (res *pdu, err error) { + var txnId uint16 + + for { + // grab a frame + res, txnId, err = tt.readMBAPFrame() + + // ignore unknown protocol identifiers + if err == ErrUnknownProtocolId { + continue + } + + // abort on any other erorr + if err != nil { + return + } + + // ignore unknown transaction identifiers + if tt.lastTxnId != txnId { + tt.logger.Warningf("received unexpected transaction id " + + "(expected 0x%04x, received 0x%04x)", + tt.lastTxnId, txnId) + continue + } + + break + } + + return +} + +// Reads an entire frame (MBAP header + modbus PDU) from the socket. +func (tt *tcpTransport) readMBAPFrame() (p *pdu, txnId uint16, err error) { + var rxbuf []byte + var bytesNeeded int + var protocolId uint16 + var unitId uint8 + + // read the MBAP header + rxbuf = make([]byte, mbapHeaderLength) + _, err = io.ReadFull(tt.socket, rxbuf) + if err != nil { + return + } + + // decode the transaction identifier + txnId = bytesToUint16(BIG_ENDIAN, rxbuf[0:2]) + // decode the protocol identifier + protocolId = bytesToUint16(BIG_ENDIAN, rxbuf[2:4]) + // store the source unit id + unitId = rxbuf[6] + + // determine how many more bytes we need to read + bytesNeeded = int(bytesToUint16(BIG_ENDIAN, rxbuf[4:6])) + + // the byte count includes the unit ID field, which we already have + bytesNeeded-- + + // never read more than the max allowed frame length + if bytesNeeded + mbapHeaderLength > maxTCPFrameLength { + err = ErrProtocolError + return + } + + // an MBAP length of 0 is illegal + if bytesNeeded <= 0 { + err = ErrProtocolError + return + } + + // read the PDU + rxbuf = make([]byte, bytesNeeded) + _, err = io.ReadFull(tt.socket, rxbuf) + if err != nil { + return + } + + // validate the protocol identifier + if protocolId != 0x0000 { + err = ErrUnknownProtocolId + tt.logger.Warningf("received unexpected protocol id 0x%04x", protocolId) + return + } + + // store unit id, function code and payload in the PDU object + p = &pdu{ + unitId: unitId, + functionCode: rxbuf[0], + payload: rxbuf[1:], + } + + return +} + +// Turns a PDU into an MBAP frame (MBAP header + PDU) and returns it as bytes. +func (tt *tcpTransport) assembleMBAPFrame(txnId uint16, p *pdu) (payload []byte) { + // transaction identifier + payload = uint16ToBytes(BIG_ENDIAN, txnId) + // protocol identifier (always 0x0000) + payload = append(payload, 0x00, 0x00) + // length (covers unit identifier + function code + payload fields) + payload = append(payload, uint16ToBytes(BIG_ENDIAN, uint16(2 + len(p.payload)))...) + // unit identifier + payload = append(payload, p.unitId) + // function code + payload = append(payload, p.functionCode) + // payload + payload = append(payload, p.payload...) + + return +} diff --git a/tcp_transport_test.go b/tcp_transport_test.go new file mode 100644 index 0000000..39ac770 --- /dev/null +++ b/tcp_transport_test.go @@ -0,0 +1,373 @@ +package modbus + +import ( + "io" + "net" + "testing" + "time" +) + +func TestAssembleMBAPFrame(t *testing.T) { + var tt *tcpTransport + var frame []byte + + tt = &tcpTransport{} + + frame = tt.assembleMBAPFrame(0x9219, &pdu{ + unitId: 0x33, + functionCode: 0x11, + payload: []byte{0x22, 0x33, 0x44, 0x55}, + }) + // expect 7 bytes of MBAP header + 1 bytes of function code + 4 bytes of payload + if len(frame) != 12 { + t.Errorf("expected 12 bytes, got %v", len(frame)) + } + for i, b := range []byte{ + 0x92, 0x19, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x06, // length (big endian) + 0x33, 0x11, // unit id and function code + 0x22, 0x33, // payload + 0x44, 0x55, // payload + } { + if frame[i] != b { + t.Errorf("expected 0x%02x at position %v, got 0x%02x", b, i, frame[i]) + } + } + + frame = tt.assembleMBAPFrame(0x921a, &pdu{ + unitId: 0x31, + functionCode: 0x06, + payload: []byte{0x12, 0x34}, + }) + // expect 7 bytes of MBAP header + 1 bytes of function code + 2 bytes of payload + if len(frame) != 10 { + t.Errorf("expected 10 bytes, got %v", len(frame)) + } + for i, b := range []byte{ + 0x92, 0x1a, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x04, // length (big endian) + 0x31, 0x06, // unit id and function code + 0x12, 0x34, // payload + } { + if frame[i] != b { + t.Errorf("expected 0x%02x at position %v, got 0x%02x", b, i, frame[i]) + } + } + + return +} + +func TestTCPTransportReadResponse(t *testing.T) { + var tt *tcpTransport + var p1, p2 net.Conn + var txchan chan []byte + var err error + var res *pdu + + txchan = make(chan []byte, 2) + p1, p2 = net.Pipe() + go feedTestPipe(t, txchan, p1) + + + tt = newTCPTransport(p2, 10 * time.Millisecond, nil) + tt.lastTxnId = 0x9218 + + // read a valid response + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x04, // length (big endian) + 0x31, 0x06, // unit id and function code + 0x12, 0x34, // payload + } + res, err = tt.readResponse() + if err != nil { + t.Errorf("readResponse() should have succeeded, got %v", err) + } + if res.unitId != 0x31 { + t.Errorf("expected 0x31 as unit id, got 0x%02x", res.unitId) + } + if res.functionCode != 0x06 { + t.Errorf("expected 0x06 as function code, got 0x%02x", res.functionCode) + } + if len(res.payload) != 2 { + t.Errorf("expected a length of 2, got %v", len(res.payload)) + } + if res.payload[0] != 0x12 || res.payload[1] != 0x34 { + t.Errorf("expected {0x12, 0x34} as payload, got {0x%02x, 0x%02x}", + res.payload[0], res.payload[1]) + } + + // read a frame with an unexpected transaction id followed by a frame with a + // matching transaction id: the first frame should be silently skipped + txchan <- []byte{ + 0x92, 0x19, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x04, // length (big endian) + 0x31, 0x06, // unit id and function code + 0x12, 0x34, // payload + } + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x04, // length (big endian) + 0x39, 0x02, // unit id and function code + 0x10, 0x01, // payload + } + res, err = tt.readResponse() + if err != nil { + t.Errorf("readResponse() should have succeeded, got %v", err) + } + if res.unitId != 0x39 { + t.Errorf("expected 0x39 as unit id, got 0x%02x", res.unitId) + } + if res.functionCode != 0x02 { + t.Errorf("expected 0x02 as function code, got 0x%02x", res.functionCode) + } + if len(res.payload) != 2 { + t.Errorf("expected a length of 2, got %v", len(res.payload)) + } + if res.payload[0] != 0x10 || res.payload[1] != 0x01 { + t.Errorf("expected {0x10, 0x01 as payload, got {0x%02x, 0x%02x}", + res.payload[0], res.payload[1]) + } + + // read a frame with an illegal length, preceded by a frame with an unexpected + // protocol ID. While the first frame should be skipped without error, + // the second should yield an ErrProtocolError. + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x01, // protocol identifier + 0x00, 0x04, // length (big endian) + 0x31, 0x06, // unit id and function code + 0x12, 0x34, // payload + } + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x01, // length (big endian) + 0x31, // unit id + } + res, err = tt.readResponse() + if err != ErrProtocolError { + t.Errorf("readResponse() should have returned ErrProtocolError, got %v", err) + } + + // read a valid frame again + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x0a, // length (big endian) + 0x31, 0x32, // unit id and function code + 0x44, 0x55, // payload + 0x66, 0x77, // payload + 0x88, 0x99, // payload + 0xaa, 0xbb, // payload + } + res, err = tt.readResponse() + if err != nil { + t.Errorf("readResponse() should have succeeded, got %v", err) + } + if res.unitId != 0x31 { + t.Errorf("expected 0x31 as unit id, got 0x%02x", res.unitId) + } + if res.functionCode != 0x32 { + t.Errorf("expected 0x32 as response code, got 0x%02x", res.functionCode) + } + if len(res.payload) != 8 { + t.Errorf("expected a length of 8, got %v", len(res.payload)) + } + for i, b := range []byte{ + 0x44, 0x55, + 0x66, 0x77, + 0x88, 0x99, + 0xaa, 0xbb, + } { + if res.payload[i] != b { + t.Errorf("expected 0x%02x at position %v, got 0x%02x", + b, i, res.payload[i]) + } + } + + // read a huge frame + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x10, 0x0a, // length (big endian) + 0x31, // unit id + } + res, err = tt.readResponse() + if err != ErrProtocolError { + t.Errorf("readResponse() should have returned ErrProtocolError, got %v", err) + } + + p1.Close() + p2.Close() + + return +} + +func TestTCPTransportReadRequest(t *testing.T) { + var tt *tcpTransport + var p1, p2 net.Conn + var txchan chan []byte + var err error + var req *pdu + + txchan = make(chan []byte, 2) + p1, p2 = net.Pipe() + go feedTestPipe(t, txchan, p1) + + + tt = newTCPTransport(p2, 10 * time.Millisecond, nil) + tt.lastTxnId = 0x0a00 + + // push three frames in a row: + // - the first with an unknown protocol ID + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x01, // protocol identifier + 0x00, 0x04, // length (big endian) + 0x31, 0x06, // unit id and function code + 0x12, 0x34, // payload + } + // - the second with an illegal length + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x01, // length (big endian) + 0x31, // unit id + } + // - the thid with a valid request + txchan <- []byte{ + 0x92, 0x18, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x0a, // length (big endian) + 0xfa, 0x04, // unit id and function code + 0x44, 0x55, // payload + 0x66, 0x77, // payload + 0x88, 0x99, // payload + 0xaa, 0xbb, // payload + } + + // read the first frame + req, err = tt.ReadRequest() + if req != nil || err != ErrUnknownProtocolId { + t.Errorf("ReadRequest() should have returned {nil, ErrUnknownProtocolId}, got {%v, %v}", req, err) + } + if tt.lastTxnId != 0x0a00 { + t.Errorf("tt.lastTxnId should have been 0x0a00, saw 0x%02x", tt.lastTxnId) + } + + // read the second frame + req, err = tt.ReadRequest() + if req != nil || err != ErrProtocolError { + t.Errorf("ReadRequest() should have returned {nil, ErrProtocolError}, got {%v, %v}", req, err) + } + if tt.lastTxnId != 0x0a00 { + t.Errorf("tt.lastTxnId should have been 0x0a00, saw 0x%02x", tt.lastTxnId) + } + + // read the third frame + req, err = tt.ReadRequest() + if err != nil { + t.Errorf("ReadRequest() should have succeeded, got %v", err) + } + if req == nil { + t.Errorf("ReadREsponse() should have returned a non-nil request") + } + if req.unitId != 0xfa { + t.Errorf("expected 0xfa as unit id, got 0x%02x", req.unitId) + } + if req.functionCode != 0x04 { + t.Errorf("expected 0x04 as response code, got 0x%02x", req.functionCode) + } + if len(req.payload) != 8 { + t.Errorf("expected a length of 8, got %v", len(req.payload)) + } + for i, b := range []byte{ + 0x44, 0x55, + 0x66, 0x77, + 0x88, 0x99, + 0xaa, 0xbb, + } { + if req.payload[i] != b { + t.Errorf("expected 0x%02x at position %v, got 0x%02x", + b, i, req.payload[i]) + } + } + if tt.lastTxnId != 0x9218 { + t.Errorf("tt.lastTxnId should have been 0x0a00, saw 0x%02x", tt.lastTxnId) + } + + return +} + +func TestTCPTransportWriteResponse(t *testing.T) { + var tt *tcpTransport + var p1, p2 net.Conn + var done chan bool + var err error + + done = make(chan bool, 0) + p1, p2 = net.Pipe() + go func(t *testing.T, pipe net.Conn, done chan bool) { + var err error + var rxbuf []byte + var expected []byte + + expected = []byte{ + 0xc0, 0x1f, // transaction identifier (big endian) + 0x00, 0x00, // protocol identifier + 0x00, 0x0b, // length (big endian) + 0x17, 0x06, // unit id and function code + 0x44, 0x55, // payload + 0x66, 0x77, // payload + 0x88, 0x99, // payload + 0xaa, 0xbb, // payload + 0xf4, // payload + } + + rxbuf = make([]byte, len(expected)) + _, err = io.ReadFull(pipe, rxbuf) + if err != nil { + t.Errorf("failed to read frame: %v", err) + } + + for i, b := range expected { + if rxbuf[i] != b { + t.Errorf("expected 0x%02x at position %v, got 0x%02x", + b, i, rxbuf[i]) + } + } + + done<- true + return + }(t, p2, done) + + + tt = newTCPTransport(p1, 10 * time.Millisecond, nil) + tt.lastTxnId = 0xc01f + + err = tt.WriteResponse(&pdu{ + unitId: 0x17, + functionCode: 0x06, + payload: []byte{ + 0x44, 0x55, // payload + 0x66, 0x77, // payload + 0x88, 0x99, // payload + 0xaa, 0xbb, // payload + 0xf4, // payload + }, + }) + if err != nil { + t.Errorf("WriteResponse() should have succeeded, got %v", err) + } + + // wait for the checker goroutine to return + <-done + + return +} diff --git a/tls_utils.go b/tls_utils.go new file mode 100644 index 0000000..0539246 --- /dev/null +++ b/tls_utils.go @@ -0,0 +1,116 @@ +package modbus + +import ( + "crypto/x509" + "fmt" + "io/ioutil" + "net" + "os" + "time" +) + +// LoadCertPool loads a certificate store from a file into a CertPool object. +func LoadCertPool(filePath string) (cp *x509.CertPool, err error) { + var buf []byte + + // read the entire cert store, which may contain zero, one + // or more certificates + buf, err = ioutil.ReadFile(filePath) + if err != nil { + return + } + + if len(buf) == 0 { + err = fmt.Errorf("%v: empty file", filePath) + return + } + + // add these certs to the pool + cp = x509.NewCertPool() + cp.AppendCertsFromPEM(buf) + + // let the caller know if no usable certificate was found + if len(cp.Subjects()) == 0 { + err = fmt.Errorf("%v: no certificate found", filePath) + return + } + + return +} + +// tlsSockWrapper wraps a TLS socket to work around odd error handling in +// TLSConn on internal connection state corruption. +// tlsSockWrapper implements the net.Conn interface to allow its +// use by the modbus TCP transport. +type tlsSockWrapper struct { + sock net.Conn +} + +func newTLSSockWrapper(sock net.Conn) (tsw *tlsSockWrapper) { + tsw = &tlsSockWrapper{ + sock: sock, + } + + return +} + +func (tsw *tlsSockWrapper) Read(buf []byte) (rlen int, err error) { + rlen, err = tsw.sock.Read(buf) + + return +} + +func (tsw *tlsSockWrapper) Write(buf []byte) (wlen int, err error) { + wlen, err = tsw.sock.Write(buf) + + // since write timeouts corrupt the internal state of TLS sockets, + // any subsequent read/write operation will fail and return the same write + // timeout error (see https://pkg.go.dev/crypto/tls#Conn.SetWriteDeadline). + // this isn't all that helpful to clients, which may be tricked into + // retrying forever, treating timeout errors as transient. + // to avoid this, close the TLS socket after the first write timeout. + // this ensures that clients 1) get a timeout error on the first write timeout + // and 2) get an ErrNetClosing "use of closed network connection" on subsequent + // operations. + if err != nil && os.IsTimeout(err) { + tsw.sock.Close() + } + + return +} + +func (tsw *tlsSockWrapper) Close() (err error) { + err = tsw.sock.Close() + + return +} + +func (tsw *tlsSockWrapper) SetDeadline(deadline time.Time) (err error) { + err = tsw.sock.SetDeadline(deadline) + + return +} + +func (tsw *tlsSockWrapper) SetReadDeadline(deadline time.Time) (err error) { + err = tsw.sock.SetReadDeadline(deadline) + + return +} + +func (tsw *tlsSockWrapper) SetWriteDeadline(deadline time.Time) (err error) { + err = tsw.sock.SetWriteDeadline(deadline) + + return +} + +func (tsw *tlsSockWrapper) LocalAddr() (addr net.Addr) { + addr = tsw.sock.LocalAddr() + + return +} + +func (tsw *tlsSockWrapper) RemoteAddr() (addr net.Addr) { + addr = tsw.sock.RemoteAddr() + + return +} diff --git a/tls_utils_test.go b/tls_utils_test.go new file mode 100644 index 0000000..34e6931 --- /dev/null +++ b/tls_utils_test.go @@ -0,0 +1,151 @@ +package modbus + +import ( + "crypto/x509" + "io/ioutil" + "os" + "testing" +) + +// random certs from /etc/ssl/certs +const validCerts = `-----BEGIN CERTIFICATE----- +MIIFcDCCA1igAwIBAgIEAJiWjTANBgkqhkiG9w0BAQsFADBYMQswCQYDVQQGEwJO +TDEeMBwGA1UECgwVU3RhYXQgZGVyIE5lZGVybGFuZGVuMSkwJwYDVQQDDCBTdGFh +dCBkZXIgTmVkZXJsYW5kZW4gRVYgUm9vdCBDQTAeFw0xMDEyMDgxMTE5MjlaFw0y +MjEyMDgxMTEwMjhaMFgxCzAJBgNVBAYTAk5MMR4wHAYDVQQKDBVTdGFhdCBkZXIg +TmVkZXJsYW5kZW4xKTAnBgNVBAMMIFN0YWF0IGRlciBOZWRlcmxhbmRlbiBFViBS +b290IENBMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA48d+ifkkSzrS +M4M1LGns3Amk41GoJSt5uAg94JG6hIXGhaTK5skuU6TJJB79VWZxXSzFYGgEt9nC +UiY4iKTWO0Cmws0/zZiTs1QUWJZV1VD+hq2kY39ch/aO5ieSZxeSAgMs3NZmdO3d +Z//BYY1jTw+bbRcwJu+r0h8QoPnFfxZpgQNH7R5ojXKhTbImxrpsX23Wr9GxE46p +rfNeaXUmGD5BKyF/7otdBwadQ8QpCiv8Kj6GyzyDOvnJDdrFmeK8eEEzduG/L13l +pJhQDBXd4Pqcfzho0LKmeqfRMb1+ilgnQ7O6M5HTp5gVXJrm0w912fxBmJc+qiXb +j5IusHsMX/FjqTf5m3VpTCgmJdrV8hJwRVXj33NeN/UhbJCONVrJ0yPr08C+eKxC +KFhmpUZtcALXEPlLVPxdhkqHz3/KRawRWrUgUY0viEeXOcDPusBCAUCZSCELa6fS +/ZbV0b5GnUngC6agIk440ME8MLxwjyx1zNDFjFE7PZQIZCZhfbnDZY8UnCHQqv0X +cgOPvZuM5l5Tnrmd74K74bzickFbIZTTRTeU0d8JOV3nI6qaHcptqAqGhYqCvkIH +1vI4gnPah1vlPNOePqc7nvQDs/nxfRN0Av+7oeX6AHkcpmZBiFxgV6YuCcS6/ZrP +px9Aw7vMWgpVSzs4dlG4Y4uElBbmVvMCAwEAAaNCMEAwDwYDVR0TAQH/BAUwAwEB +/zAOBgNVHQ8BAf8EBAMCAQYwHQYDVR0OBBYEFP6rAJCYniT8qcwaivsnuL8wbqg7 +MA0GCSqGSIb3DQEBCwUAA4ICAQDPdyxuVr5Os7aEAJSrR8kN0nbHhp8dB9O2tLsI +eK9p0gtJ3jPFrK3CiAJ9Brc1AsFgyb/E6JTe1NOpEyVa/m6irn0F3H3zbPB+po3u +2dfOWBfoqSmuc0iH55vKbimhZF8ZE/euBhD/UcabTVUlT5OZEAFTdfETzsemQUHS +v4ilf0X8rLiltTMMgsT7B/Zq5SWEXwbKwYY5EdtYzXc7LMJMD16a4/CrPmEbUCTC +wPTxGfARKbalGAKb12NMcIxHowNDXLldRqANb/9Zjr7dn3LDWyvfjFvO5QxGbJKy +CqNMVEIYFRIYvdr8unRu/8G2oGTYqV9Vrp9canaW2HNnh/tNf1zuacpzEPuKqf2e +vTY4SUmH9A4U8OmHuD+nT3pajnnUk+S7aFKErGzp85hwVXIy+TSrK0m1zSBi5Dp6 +Z2Orltxtrpfs/J92VoguZs9btsmksNcFuuEnL5O7Jiqik7Ab846+HUCjuTaPPoIa +Gl6I6lD4WeKDRikL40Rc4ZW2aZCaFG+XroHPaO+Zmr615+F/+PoTRxZMzG0IQOeL +eG9QgkRQP2YGiqtDhFZKDyAthg710tvSeopLzaXoTvFeJiUBWSOgftL2fiFX1ye8 +FVdMpEbB4IMeDExNH08GGeL5qPQ6gqGyeUN51q1veieQA6TqJIc/2b3Z6fJfUEkc +7uzXLg== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDnzCCAoegAwIBAgIBJjANBgkqhkiG9w0BAQUFADBxMQswCQYDVQQGEwJERTEc +MBoGA1UEChMTRGV1dHNjaGUgVGVsZWtvbSBBRzEfMB0GA1UECxMWVC1UZWxlU2Vj +IFRydXN0IENlbnRlcjEjMCEGA1UEAxMaRGV1dHNjaGUgVGVsZWtvbSBSb290IENB +IDIwHhcNOTkwNzA5MTIxMTAwWhcNMTkwNzA5MjM1OTAwWjBxMQswCQYDVQQGEwJE +RTEcMBoGA1UEChMTRGV1dHNjaGUgVGVsZWtvbSBBRzEfMB0GA1UECxMWVC1UZWxl +U2VjIFRydXN0IENlbnRlcjEjMCEGA1UEAxMaRGV1dHNjaGUgVGVsZWtvbSBSb290 +IENBIDIwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCrC6M14IspFLEU +ha88EOQ5bzVdSq7d6mGNlUn0b2SjGmBmpKlAIoTZ1KXleJMOaAGtuU1cOs7TuKhC +QN/Po7qCWWqSG6wcmtoIKyUn+WkjR/Hg6yx6m/UTAtB+NHzCnjwAWav12gz1Mjwr +rFDa1sPeg5TKqAyZMg4ISFZbavva4VhYAUlfckE8FQYBjl2tqriTtM2e66foai1S +NNs671x1Udrb8zH57nGYMsRUFUQM+ZtV7a3fGAigo4aKSe5TBY8ZTNXeWHmb0moc +QqvF1afPaA+W5OFhmHZhyJF81j4A4pFQh+GdCuatl9Idxjp9y7zaAzTVjlsB9WoH +txa2bkp/AgMBAAGjQjBAMB0GA1UdDgQWBBQxw3kbuvVT1xfgiXotF2wKsyudMzAP +BgNVHRMECDAGAQH/AgEFMA4GA1UdDwEB/wQEAwIBBjANBgkqhkiG9w0BAQUFAAOC +AQEAlGRZrTlk5ynrE/5aw4sTV8gEJPB0d8Bg42f76Ymmg7+Wgnxu1MM9756Abrsp +tJh6sTtU6zkXR34ajgv8HzFZMQSyzhfzLMdiNlXiItiJVbSYSKpk+tYcNthEeFpa +IzpXl/V6ME+un2pMSyuOoAPjPuCp1NJ70rOo4nI8rZ7/gFnkm0W09juwzTkZmDLl +6iFhkOQxIY40sfcvNUqFENrnijchvllj4PKFiDFT1FQUhXB59C4Gdyd1Lx+4ivn+ +xbrYNuSD7Odlt79jWvNGr4GUN9RBjNYj1h7P9WgbRGOiWrqnNVmh5XAFmw4jV5mU +Cm26OWMohpLzGITY+9HPBVZkVw== +-----END CERTIFICATE----- +` + +func TestLoadCertPool(t *testing.T) { + var err error + var cp *x509.CertPool + var fd *os.File + var path string + + // attemp to load a non-existent file: should fail + cp, err = LoadCertPool("non/existent/path/to/store") + if err == nil { + t.Errorf("LoadCertPool() should have failed") + } + + // create an empty file and attempt to load it: should fail + fd, err = ioutil.TempFile("", "modbus_tls_utils_test") + if err != nil { + t.Errorf("failed to create temp file: %v", err) + return + } + path = fd.Name() + + defer os.Remove(path) + err = fd.Close() + if err != nil { + t.Errorf("failed to close temp file: %v", err) + return + } + + cp, err = LoadCertPool(path) + if err == nil { + t.Errorf("LoadCertPool() should have failed") + } + + // put garbage into a file and attempt to load it: should fail + fd, err = ioutil.TempFile("", "modbus_tls_utils_test") + if err != nil { + t.Errorf("failed to create temp file: %v", err) + } + path = fd.Name() + + defer os.Remove(path) + _, err = fd.Write([]byte("somejunk")) + if err != nil { + t.Errorf("failed to write to temp file: %v", err) + } + err = fd.Close() + if err != nil { + t.Errorf("failed to close temp file: %v", err) + return + } + + cp, err = LoadCertPool(path) + if err == nil { + t.Errorf("LoadCertPool() should have failed") + } + + // now write two certs to a file and try to load it: should succeed + fd, err = ioutil.TempFile("", "modbus_tls_utils_test") + if err != nil { + t.Errorf("failed to create temp file: %v", err) + } + path = fd.Name() + + defer os.Remove(path) + _, err = fd.Write([]byte(validCerts)) + if err != nil { + t.Errorf("failed to write to temp file: %v", err) + } + err = fd.Close() + if err != nil { + t.Errorf("failed to close temp file: %v", err) + return + } + + cp, err = LoadCertPool(path) + if err != nil { + t.Errorf("LoadCertPool() should have succeeded, got: %v", err) + } + + // expect two certs in the cert pool + if len(cp.Subjects()) != 2 { + t.Errorf("expected 2 certs in the pool, saw: %v", len(cp.Subjects())) + } + + return +} diff --git a/transport.go b/transport.go new file mode 100644 index 0000000..3e02013 --- /dev/null +++ b/transport.go @@ -0,0 +1,18 @@ +package modbus + +type transportType uint +const ( + modbusRTU transportType = 1 + modbusRTUOverTCP transportType = 2 + modbusRTUOverUDP transportType = 3 + modbusTCP transportType = 4 + modbusTCPOverTLS transportType = 5 + modbusTCPOverUDP transportType = 6 +) + +type transport interface { + Close() (error) + ExecuteRequest(*pdu) (*pdu, error) + ReadRequest() (*pdu, error) + WriteResponse(*pdu) (error) +} diff --git a/udp.go b/udp.go new file mode 100644 index 0000000..8828467 --- /dev/null +++ b/udp.go @@ -0,0 +1,102 @@ +package modbus + +import ( + "net" + "time" +) + +// udpSockWrapper wraps a net.UDPConn (UDP socket) to +// allow transports to consume data off the network socket on +// a byte per byte basis rather than datagram by datagram. +type udpSockWrapper struct { + leftoverCount int + rxbuf []byte + sock *net.UDPConn +} + +func newUDPSockWrapper(sock net.Conn) (usw *udpSockWrapper) { + usw = &udpSockWrapper{ + rxbuf: make([]byte, maxTCPFrameLength), + sock: sock.(*net.UDPConn), + } + + return +} + +func (usw *udpSockWrapper) Read(buf []byte) (rlen int, err error) { + var copied int + + if usw.leftoverCount > 0 { + // if we're holding onto any bytes from a previous datagram, + // use them to satisfy the read (potentially partially) + copied = copy(buf, usw.rxbuf[0:usw.leftoverCount]) + + if usw.leftoverCount > copied { + // move any leftover bytes to the beginning of the buffer + copy(usw.rxbuf, usw.rxbuf[copied:usw.leftoverCount]) + } + // make a note of how many leftover bytes we have in the buffer + usw.leftoverCount -= copied + } else { + // read up to maxTCPFrameLength bytes from the socket + rlen, err = usw.sock.Read(usw.rxbuf) + if err != nil { + return + } + // copy as many bytes as possible to satisfy the read + copied = copy(buf, usw.rxbuf[0:rlen]) + + if rlen > copied { + // move any leftover bytes to the beginning of the buffer + copy(usw.rxbuf, usw.rxbuf[copied:rlen]) + } + // make a note of how many leftover bytes we have in the buffer + usw.leftoverCount = rlen - copied + } + + rlen = copied + + return +} + +func (usw *udpSockWrapper) Close() (err error) { + err = usw.sock.Close() + + return +} + +func (usw *udpSockWrapper) Write(buf []byte) (wlen int, err error) { + wlen, err = usw.sock.Write(buf) + + return +} + +func (usw *udpSockWrapper) SetDeadline(deadline time.Time) (err error) { + err = usw.sock.SetDeadline(deadline) + + return +} + +func (usw *udpSockWrapper) SetReadDeadline(deadline time.Time) (err error) { + err = usw.sock.SetReadDeadline(deadline) + + return +} + +func (usw *udpSockWrapper) SetWriteDeadline(deadline time.Time) (err error) { + err = usw.sock.SetWriteDeadline(deadline) + + return +} + +func (usw *udpSockWrapper) LocalAddr() (addr net.Addr) { + addr = usw.sock.LocalAddr() + + return +} + +func (usw *udpSockWrapper) RemoteAddr() (addr net.Addr) { + addr = usw.sock.RemoteAddr() + + return +} diff --git a/udp_test.go b/udp_test.go new file mode 100644 index 0000000..4f738f5 --- /dev/null +++ b/udp_test.go @@ -0,0 +1,166 @@ +package modbus + +import ( + "net" + "os" + "testing" + "time" +) + +func TestUDPSockWrapper(t *testing.T) { + var err error + var usw *udpSockWrapper + var sock1 *net.UDPConn + var sock2 *net.UDPConn + var addr *net.UDPAddr + var txchan chan []byte + var rxbuf []byte + var count int + + addr, err = net.ResolveUDPAddr("udp", "localhost:5502") + if err != nil { + t.Errorf("failed to resolve udp address: %v", err) + return + } + + txchan = make(chan []byte, 4) + // get a pair of UDP sockets ready to talk to each other + sock1, err = net.ListenUDP("udp", addr) + if err != nil { + t.Errorf("failed to listen on udp socket: %v", err) + return + } + err = sock1.SetReadDeadline(time.Now().Add(1 * time.Second)) + if err != nil { + t.Errorf("failed to set deadline on udp socket: %v", err) + return + } + + sock2, err = net.DialUDP("udp", nil, addr) + if err != nil { + t.Errorf("failed to open udp socket: %v", err) + return + } + // the feedTestPipe goroutine will forward any slice of bytes + // pushed into txchan over UDP to our test UDP sock wrapper object + go feedTestPipe(t, txchan, sock2) + + usw = newUDPSockWrapper(sock1) + // push a valid RTU response (illegal data address) to the test pipe + txchan <- []byte{ + 0x31, 0x82, // unit id and response code + 0x02, // exception code + 0xc1, 0x6e, // CRC + } + // then push random junk + txchan <-[]byte{ + 0xaa, 0xbb, 0xcc, + } + // then some more + txchan <-[]byte{ + 0xdd, 0xee, + } + + // attempt to read 3 bytes: we should get them as the first datagram + // is 5 bytes long + rxbuf = make([]byte, 3) + count, err = usw.Read(rxbuf) + if err != nil { + t.Errorf("usw.Read() should have succeeded, got: %v", err) + } + if count != 3 { + t.Errorf("expected 3 bytes, got: %v", count) + } + for idx, val := range []byte{ + 0x31, 0x82, 0x02, + } { + if rxbuf[idx] != val { + t.Errorf("expected 0x%02x at pos %v, got: 0x%02x", + val, idx, rxbuf[idx]) + } + } + + // attempt to read 1 byte: we should get the 4th byte of the + // first datagram, of which we've been holding on to bytes #4 and 5 + rxbuf = make([]byte, 1) + count, err = usw.Read(rxbuf) + if err != nil { + t.Errorf("usw.Read() should have succeeded, got: %v", err) + } + if count != 1 { + t.Errorf("expected 1 byte, got: %v", count) + } + if rxbuf[0] != 0xc1 { + t.Errorf("expected 0xc1 at pos 0, got: 0x%02x", rxbuf[0]) + } + + // attempt to read 5 bytes: we should get the last byte of the + // first datagram, which the udpSockWrapper object still holds in + // its buffer + rxbuf = make([]byte, 5) + count, err = usw.Read(rxbuf) + if err != nil { + t.Errorf("usw.Read() should have succeeded, got: %v", err) + } + if count != 1 { + t.Errorf("expected 1 byte, got: %v", count) + } + if rxbuf[0] != 0x6e { + t.Errorf("expected 0x6e at pos 0, got: 0x%02x", rxbuf[0]) + } + + // attempt to read 10 bytes: we should get all 3 bytes of the 2nd + // datagram + rxbuf = make([]byte, 10) + count, err = usw.Read(rxbuf) + if err != nil { + t.Errorf("usw.Read() should have succeeded, got: %v", err) + } + if count != 3 { + t.Errorf("expected 3 bytes, got: %v", count) + } + for idx, val := range []byte{ + 0xaa, 0xbb, 0xcc, + } { + if rxbuf[idx] != val { + t.Errorf("expected 0x%02x at pos %v, got: 0x%02x", + val, idx, rxbuf[idx]) + } + } + + // attempt to read 40 bytes: we should get both bytes of the 3rd + // datagram + rxbuf = make([]byte, 40) + count, err = usw.Read(rxbuf) + if err != nil { + t.Errorf("usw.Read() should have succeeded, got: %v", err) + } + if count != 2 { + t.Errorf("expected 2 bytes, got: %v", count) + } + for idx, val := range []byte{ + 0xdd, 0xee, + } { + if rxbuf[idx] != val { + t.Errorf("expected 0x%02x at pos %v, got: 0x%02x", + val, idx, rxbuf[idx]) + } + } + + // attempt to read 7 bytes: we should get a read timeout as we've + // consumed all bytes from all datagrams and no more are coming + rxbuf = make([]byte, 7) + count, err = usw.Read(rxbuf) + if !os.IsTimeout(err) { + t.Errorf("usw.Read() should have failed with a timeout error, got: %v", err) + } + if count != 0 { + t.Errorf("expected 0 bytes, got: %v", count) + } + + // cleanup + sock1.Close() + sock2.Close() + + return +}