Resetting master

This commit is contained in:
Ludovic Chenut 2024-03-08 13:50:59 +01:00
parent 4406004a0d
commit a0f6c777d8
No known key found for this signature in database
GPG Key ID: D9A59B1907F1D50C
16 changed files with 5 additions and 1771 deletions

View File

@ -61,6 +61,6 @@ jobs:
run: |
nim --version
nimble --version
nimble test
nim c examples/ping.nim
nim c examples/pong.nim
# nimble test
# nim c examples/ping.nim
# nim c examples/pong.nim

View File

@ -1,24 +0,0 @@
import chronos, stew/byteutils
import ../webrtc/udp_connection
import ../webrtc/stun/stun_connection
import ../webrtc/dtls/dtls
import ../webrtc/sctp
proc main() {.async.} =
let laddr = initTAddress("127.0.0.1:4244")
let udp = UdpConn()
udp.init(laddr)
let stun = StunConn()
stun.init(udp, laddr)
let dtls = Dtls()
dtls.init(stun, laddr)
let sctp = Sctp()
sctp.init(dtls, laddr)
let conn = await sctp.connect(initTAddress("127.0.0.1:4242"), sctpPort = 13)
while true:
await conn.write("ping".toBytes)
let msg = await conn.read()
echo "Received: ", string.fromBytes(msg.data)
await sleepAsync(1.seconds)
waitFor(main())

View File

@ -1,30 +0,0 @@
import chronos, stew/byteutils
import ../webrtc/udp_connection
import ../webrtc/stun/stun_connection
import ../webrtc/dtls/dtls
import ../webrtc/sctp
proc sendPong(conn: SctpConn) {.async.} =
var i = 0
while true:
let msg = await conn.read()
echo "Received: ", string.fromBytes(msg.data)
await conn.write(("pong " & $i).toBytes)
i.inc()
proc main() {.async.} =
let laddr = initTAddress("127.0.0.1:4242")
let udp = UdpConn()
udp.init(laddr)
let stun = StunConn()
stun.init(udp, laddr)
let dtls = Dtls()
dtls.init(stun, laddr)
let sctp = Sctp()
sctp.init(dtls, laddr)
sctp.listen(13)
while true:
let conn = await sctp.accept()
asyncSpawn conn.sendPong()
waitFor(main())

View File

@ -1,4 +0,0 @@
{.used.}
import testdatachannel
import teststun

View File

@ -1,25 +0,0 @@
import ../webrtc/datachannel
import chronos/unittest2/asynctests
import binary_serialization
suite "DataChannel encoding":
test "DataChannelOpenMessage":
let msg = @[
0x03'u8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x03, 0x00, 0x03, 0x66, 0x6f, 0x6f, 0x62, 0x61, 0x72]
check msg == Binary.encode(Binary.decode(msg, DataChannelMessage))
check Binary.decode(msg, DataChannelMessage).openMessage ==
DataChannelOpenMessage(
channelType: Reliable,
priority: 0,
reliabilityParameter: 0,
labelLength: 3,
protocolLength: 3,
label: @[102, 111, 111],
protocol: @[98, 97, 114]
)
test "DataChannelAck":
let msg = @[0x02'u8]
check msg == Binary.encode(Binary.decode(msg, DataChannelMessage))
check Binary.decode(msg, DataChannelMessage).messageType == Ack

View File

@ -1,36 +0,0 @@
import options
import ../webrtc/stun/stun
import ../webrtc/stun/stun_attributes
import ./asyncunit
suite "Stun message encoding/decoding":
test "Stun decoding":
let msg = @[ 0x00'u8, 0x01, 0x00, 0xa4, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d, 0x2b, 0x00, 0x06, 0x00, 0x63, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2b, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2b, 0x76, 0x31, 0x2f, 0x62, 0x71, 0x36, 0x67, 0x69, 0x43, 0x75, 0x4a, 0x38, 0x6e, 0x78, 0x59, 0x46, 0x4a, 0x36, 0x43, 0x63, 0x67, 0x45, 0x59, 0x58, 0x58, 0x2f, 0x78, 0x51, 0x58, 0x56, 0x4c, 0x74, 0x39, 0x71, 0x7a, 0x3a, 0x6c, 0x69, 0x62, 0x70, 0x32, 0x70, 0x2b, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2b, 0x76, 0x31, 0x2f, 0x62, 0x71, 0x36, 0x67, 0x69, 0x43, 0x75, 0x4a, 0x38, 0x6e, 0x78, 0x59, 0x46, 0x4a, 0x36, 0x43, 0x63, 0x67, 0x45, 0x59, 0x58, 0x58, 0x2f, 0x78, 0x51, 0x58, 0x56, 0x4c, 0x74, 0x39, 0x71, 0x7a, 0x00, 0xc0, 0x57, 0x00, 0x04, 0x00, 0x00, 0x03, 0xe7, 0x80, 0x2a, 0x00, 0x08, 0x86, 0x63, 0xfd, 0x45, 0xa9, 0xe5, 0x4c, 0xdb, 0x00, 0x24, 0x00, 0x04, 0x6e, 0x00, 0x1e, 0xff, 0x00, 0x08, 0x00, 0x14, 0x16, 0xff, 0x70, 0x8d, 0x97, 0x0b, 0xd6, 0xa3, 0x5b, 0xac, 0x8f, 0x4c, 0x85, 0xe6, 0xa6, 0xac, 0xaa, 0x7a, 0x68, 0x27, 0x80, 0x28, 0x00, 0x04, 0x79, 0x5e, 0x03, 0xd8 ]
let stunmsg = StunMessage.decode(msg)
check:
stunmsg.msgType == 1
stunmsg.transactionId.len() == 12
stunmsg.attributes.len() == 6
stunmsg.attributes[0].attributeType == 6 # AttrUsername
stunmsg.attributes[^1].attributeType == 0x8028 # AttrFingerprint
test "Stun encoding":
let transactionId: array[12, byte] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
var msg = StunMessage(msgType: 0x0001'u16, transactionId: transactionId)
msg.attributes.add(ErrorCode.encode(ECUnknownAttribute))
let encoded = msg.encode()
let decoded = StunMessage.decode(encoded)
# cannot do `check msg == decoded` because encode add a Fingerprint
# attribute at the end
check:
decoded.msgType == 1
decoded.transactionId == transactionId
decoded.attributes.len() == 2
decoded.attributes[0].attributeType == 9 # AttrErrorCode
decoded.attributes[^1].attributeType == 0x8028 # AttrFingerprint
test "Error while decoding":
let msgLengthFailed = @[ 0x00'u8, 0x01, 0x00, 0xa4, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d ]
expect AssertionDefect: discard StunMessage.decode(msgLengthFailed)
let msgAttrFailed = @[ 0x00'u8, 0x01, 0x00, 0x08, 0x21, 0x12, 0xa4, 0x42, 0x75, 0x6a, 0x58, 0x46, 0x42, 0x58, 0x4e, 0x72, 0x6a, 0x50, 0x4d, 0x2b, 0x28, 0x00, 0x05, 0x79, 0x5e, 0x03, 0xd8 ]
expect AssertionDefect: discard StunMessage.decode(msgAttrFailed)

View File

@ -33,5 +33,5 @@ proc runTest(filename: string) =
exec excstr & " -r " & " tests/" & filename
rmFile "tests/" & filename.toExe
task test, "Run test":
runTest("runalltests")
# task test, "Run test":
# runTest("runalltests")

View File

@ -1,227 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import tables
import chronos,
chronicles,
binary_serialization
import sctp
export binary_serialization
logScope:
topics = "webrtc datachannel"
# Implementation of the DataChannel protocol, mostly following
# https://www.rfc-editor.org/rfc/rfc8831.html and
# https://www.rfc-editor.org/rfc/rfc8832.html
type
DataChannelProtocolIds* {.size: 4.} = enum
WebRtcDcep = 50
WebRtcString = 51
WebRtcBinary = 53
WebRtcStringEmpty = 56
WebRtcBinaryEmpty = 57
DataChannelMessageType* {.size: 1.} = enum
Reserved = 0x00
Ack = 0x02
Open = 0x03
DataChannelMessage* = object
case messageType*: DataChannelMessageType
of Open: openMessage*: DataChannelOpenMessage
else: discard
DataChannelType {.size: 1.} = enum
Reliable = 0x00
PartialReliableRexmit = 0x01
PartialReliableTimed = 0x02
ReliableUnordered = 0x80
PartialReliableRexmitUnordered = 0x81
PartialReliableTimedUnorderd = 0x82
DataChannelOpenMessage* = object
channelType*: DataChannelType
priority*: uint16
reliabilityParameter*: uint32
labelLength* {.bin_value: it.label.len.}: uint16
protocolLength* {.bin_value: it.protocol.len.}: uint16
label* {.bin_len: it.labelLength.}: seq[byte]
protocol* {.bin_len: it.protocolLength.}: seq[byte]
proc ordered(t: DataChannelType): bool =
t in [Reliable, PartialReliableRexmit, PartialReliableTimed]
type
#TODO handle closing
DataChannelStream* = ref object
id: uint16
conn: SctpConn
reliability: DataChannelType
reliabilityParameter: uint32
receivedData: AsyncQueue[seq[byte]]
acked: bool
#TODO handle closing
DataChannelConnection* = ref object
readLoopFut: Future[void]
streams: Table[uint16, DataChannelStream]
streamId: uint16
conn*: SctpConn
incomingStreams: AsyncQueue[DataChannelStream]
proc read*(stream: DataChannelStream): Future[seq[byte]] {.async.} =
let x = await stream.receivedData.popFirst()
trace "read", length=x.len(), id=stream.id
return x
proc write*(stream: DataChannelStream, buf: seq[byte]) {.async.} =
trace "write", length=buf.len(), id=stream.id
var
sendInfo = SctpMessageParameters(
streamId: stream.id,
endOfRecord: true,
protocolId: uint32(WebRtcBinary)
)
if stream.acked:
sendInfo.unordered = not stream.reliability.ordered
#TODO add reliability params
if buf.len == 0:
trace "Datachannel write empty"
sendInfo.protocolId = uint32(WebRtcBinaryEmpty)
await stream.conn.write(@[0'u8], sendInfo)
else:
await stream.conn.write(buf, sendInfo)
proc sendControlMessage(stream: DataChannelStream, msg: DataChannelMessage) {.async.} =
let
encoded = Binary.encode(msg)
sendInfo = SctpMessageParameters(
streamId: stream.id,
endOfRecord: true,
protocolId: uint32(WebRtcDcep)
)
trace "send control message", msg
await stream.conn.write(encoded, sendInfo)
proc openStream*(
conn: DataChannelConnection,
noiseHandshake: bool,
reliability = Reliable, reliabilityParameter: uint32 = 0): Future[DataChannelStream] {.async.} =
let streamId: uint16 =
if not noiseHandshake:
let res = conn.streamId
conn.streamId += 2
res
else:
0
trace "open stream", streamId
if reliability in [Reliable, ReliableUnordered] and reliabilityParameter != 0:
raise newException(ValueError, "reliabilityParameter should be 0")
if streamId in conn.streams:
raise newException(ValueError, "streamId already used")
#TODO: we should request more streams when required
# https://github.com/sctplab/usrsctp/blob/a0cbf4681474fab1e89d9e9e2d5c3694fce50359/programs/rtcweb.c#L304C16-L304C16
var stream = DataChannelStream(
id: streamId, conn: conn.conn,
reliability: reliability,
reliabilityParameter: reliabilityParameter,
receivedData: newAsyncQueue[seq[byte]]()
)
conn.streams[streamId] = stream
let
msg = DataChannelMessage(
messageType: Open,
openMessage: DataChannelOpenMessage(
channelType: reliability,
reliabilityParameter: reliabilityParameter
)
)
await stream.sendControlMessage(msg)
return stream
proc handleData(conn: DataChannelConnection, msg: SctpMessage) =
let streamId = msg.params.streamId
trace "handle data message", streamId, ppid = msg.params.protocolId, data = msg.data
if streamId notin conn.streams:
raise newException(ValueError, "got data for unknown streamid")
let stream = conn.streams[streamId]
#TODO handle string vs binary
if msg.params.protocolId in [uint32(WebRtcStringEmpty), uint32(WebRtcBinaryEmpty)]:
# PPID indicate empty message
stream.receivedData.addLastNoWait(@[])
else:
stream.receivedData.addLastNoWait(msg.data)
proc handleControl(conn: DataChannelConnection, msg: SctpMessage) {.async.} =
let
decoded = Binary.decode(msg.data, DataChannelMessage)
streamId = msg.params.streamId
trace "handle control message", decoded, streamId = msg.params.streamId
if decoded.messageType == Ack:
if streamId notin conn.streams:
raise newException(ValueError, "got ack for unknown streamid")
conn.streams[streamId].acked = true
elif decoded.messageType == Open:
if streamId in conn.streams:
raise newException(ValueError, "got open for already existing streamid")
let stream = DataChannelStream(
id: streamId, conn: conn.conn,
reliability: decoded.openMessage.channelType,
reliabilityParameter: decoded.openMessage.reliabilityParameter,
receivedData: newAsyncQueue[seq[byte]]()
)
conn.streams[streamId] = stream
conn.incomingStreams.addLastNoWait(stream)
await stream.sendControlMessage(DataChannelMessage(messageType: Ack))
proc readLoop(conn: DataChannelConnection) {.async.} =
try:
while true:
let message = await conn.conn.read()
# TODO: check the protocolId
if message.params.protocolId == uint32(WebRtcDcep):
#TODO should we really await?
await conn.handleControl(message)
else:
conn.handleData(message)
except CatchableError as exc:
discard
proc accept*(conn: DataChannelConnection): Future[DataChannelStream] {.async.} =
return await conn.incomingStreams.popFirst()
proc new*(_: type DataChannelConnection, conn: SctpConn): DataChannelConnection =
result = DataChannelConnection(
conn: conn,
incomingStreams: newAsyncQueue[DataChannelStream](),
streamId: 1'u16 # TODO: Serveur == 1, client == 2
)
result.readLoopFut = result.readLoop()

View File

@ -1,377 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import times, deques, tables, sequtils
import chronos, chronicles
import ./utils, ../stun/stun_connection
import mbedtls/ssl
import mbedtls/ssl_cookie
import mbedtls/ssl_cache
import mbedtls/pk
import mbedtls/md
import mbedtls/entropy
import mbedtls/ctr_drbg
import mbedtls/rsa
import mbedtls/x509
import mbedtls/x509_crt
import mbedtls/bignum
import mbedtls/error
import mbedtls/net_sockets
import mbedtls/timing
logScope:
topics = "webrtc dtls"
# Implementation of a DTLS client and a DTLS Server by using the mbedtls library.
# Multiple things here are unintuitive partly because of the callbacks
# used by mbedtls and that those callbacks cannot be async.
#
# TODO:
# - Check the viability of the add/pop first/last of the asyncqueue with the limit.
# There might be some errors (or crashes) with some edge cases with the no wait option
# - Not critical - Check how to make a better use of MBEDTLS_ERR_SSL_WANT_WRITE
# - Not critical - May be interesting to split Dtls and DtlsConn into two files
# This limit is arbitrary, it could be interesting to make it configurable.
const PendingHandshakeLimit = 1024
# -- DtlsConn --
# A Dtls connection to a specific IP address recovered by the receiving part of
# the Udp "connection"
type
DtlsError* = object of CatchableError
DtlsConn* = ref object
conn: StunConn
laddr: TransportAddress
raddr*: TransportAddress
dataRecv: AsyncQueue[seq[byte]]
sendFuture: Future[void]
closed: bool
closeEvent: AsyncEvent
timer: mbedtls_timing_delay_context
ssl: mbedtls_ssl_context
config: mbedtls_ssl_config
cookie: mbedtls_ssl_cookie_ctx
cache: mbedtls_ssl_cache_context
ctr_drbg: mbedtls_ctr_drbg_context
entropy: mbedtls_entropy_context
localCert: seq[byte]
remoteCert: seq[byte]
proc init(self: DtlsConn, conn: StunConn, laddr: TransportAddress) =
self.conn = conn
self.laddr = laddr
self.dataRecv = newAsyncQueue[seq[byte]]()
self.closed = false
self.closeEvent = newAsyncEvent()
proc join(self: DtlsConn) {.async.} =
await self.closeEvent.wait()
proc dtlsHandshake(self: DtlsConn, isServer: bool) {.async.} =
var shouldRead = isServer
while self.ssl.private_state != MBEDTLS_SSL_HANDSHAKE_OVER:
if shouldRead:
if isServer:
case self.raddr.family
of AddressFamily.IPv4:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v4)
of AddressFamily.IPv6:
mb_ssl_set_client_transport_id(self.ssl, self.raddr.address_v6)
else:
raise newException(DtlsError, "Remote address isn't an IP address")
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
self.sendFuture = nil
let res = mb_ssl_handshake_step(self.ssl)
if not self.sendFuture.isNil():
await self.sendFuture
shouldRead = false
if res == MBEDTLS_ERR_SSL_WANT_WRITE:
continue
elif res == MBEDTLS_ERR_SSL_WANT_READ:
shouldRead = true
continue
elif res == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
mb_ssl_session_reset(self.ssl)
shouldRead = isServer
continue
elif res != 0:
raise newException(DtlsError, $(res.mbedtls_high_level_strerr()))
proc close*(self: DtlsConn) {.async.} =
if self.closed:
debug "Try to close DtlsConn twice"
return
self.closed = true
self.sendFuture = nil
# TODO: proc mbedtls_ssl_close_notify => template mb_ssl_close_notify in nim-mbedtls
let x = mbedtls_ssl_close_notify(addr self.ssl)
if not self.sendFuture.isNil():
await self.sendFuture
self.closeEvent.fire()
proc write*(self: DtlsConn, msg: seq[byte]) {.async.} =
if self.closed:
debug "Try to write on an already closed DtlsConn"
return
var buf = msg
try:
let sendFuture = newFuture[void]("DtlsConn write")
self.sendFuture = nil
let write = mb_ssl_write(self.ssl, buf)
if not self.sendFuture.isNil():
await self.sendFuture
trace "Dtls write", msgLen = msg.len(), actuallyWrote = write
except MbedTLSError as exc:
trace "Dtls write error", errorMsg = exc.msg
raise exc
proc read*(self: DtlsConn): Future[seq[byte]] {.async.} =
if self.closed:
debug "Try to read on an already closed DtlsConn"
return
var res = newSeq[byte](8192)
while true:
let tmp = await self.dataRecv.popFirst()
self.dataRecv.addFirstNoWait(tmp)
# TODO: Find a clear way to use the template `mb_ssl_read` without
# messing up things with exception
let length = mbedtls_ssl_read(addr self.ssl, cast[ptr byte](addr res[0]), res.len().uint)
if length == MBEDTLS_ERR_SSL_WANT_READ:
continue
if length < 0:
raise newException(DtlsError, $(length.cint.mbedtls_high_level_strerr()))
res.setLen(length)
return res
# -- Dtls --
# The Dtls object read every messages from the UdpConn/StunConn and, if the address
# is not yet stored in the Table `Connection`, adds it to the `pendingHandshake` queue
# to be accepted later, if the address is stored, add the message received to the
# corresponding DtlsConn `dataRecv` queue.
type
Dtls* = ref object of RootObj
connections: Table[TransportAddress, DtlsConn]
pendingHandshakes: AsyncQueue[(TransportAddress, seq[byte])]
conn: StunConn
laddr: TransportAddress
started: bool
readLoop: Future[void]
ctr_drbg: mbedtls_ctr_drbg_context
entropy: mbedtls_entropy_context
serverPrivKey: mbedtls_pk_context
serverCert: mbedtls_x509_crt
localCert: seq[byte]
proc updateOrAdd(aq: AsyncQueue[(TransportAddress, seq[byte])],
raddr: TransportAddress, buf: seq[byte]) =
for kv in aq.mitems():
if kv[0] == raddr:
kv[1] = buf
return
aq.addLastNoWait((raddr, buf))
proc init*(self: Dtls, conn: StunConn, laddr: TransportAddress) =
if self.started:
warn "Already started"
return
proc readLoop() {.async.} =
while true:
let (buf, raddr) = await self.conn.read()
if self.connections.hasKey(raddr):
self.connections[raddr].dataRecv.addLastNoWait(buf)
else:
self.pendingHandshakes.updateOrAdd(raddr, buf)
self.connections = initTable[TransportAddress, DtlsConn]()
self.pendingHandshakes = newAsyncQueue[(TransportAddress, seq[byte])](PendingHandshakeLimit)
self.conn = conn
self.laddr = laddr
self.started = true
self.readLoop = readLoop()
mb_ctr_drbg_init(self.ctr_drbg)
mb_entropy_init(self.entropy)
mb_ctr_drbg_seed(self.ctr_drbg, mbedtls_entropy_func, self.entropy, nil, 0)
self.serverPrivKey = self.ctr_drbg.generateKey()
self.serverCert = self.ctr_drbg.generateCertificate(self.serverPrivKey)
self.localCert = newSeq[byte](self.serverCert.raw.len)
copyMem(addr self.localCert[0], self.serverCert.raw.p, self.serverCert.raw.len)
proc stop*(self: Dtls) {.async.} =
if not self.started:
warn "Already stopped"
return
await allFutures(toSeq(self.connections.values()).mapIt(it.close()))
self.readLoop.cancel()
self.started = false
# -- Remote / Local certificate getter --
proc remoteCertificate*(conn: DtlsConn): seq[byte] =
conn.remoteCert
proc localCertificate*(conn: DtlsConn): seq[byte] =
conn.localCert
proc localCertificate*(self: Dtls): seq[byte] =
self.localCert
# -- MbedTLS Callbacks --
proc verify(ctx: pointer, pcert: ptr mbedtls_x509_crt,
state: cint, pflags: ptr uint32): cint {.cdecl.} =
# verify is the procedure called by mbedtls when receiving the remote
# certificate. It's usually used to verify the validity of the certificate.
# We use this procedure to store the remote certificate as it's mandatory
# to have it for the Prologue of the Noise protocol, aswell as the localCertificate.
var self = cast[DtlsConn](ctx)
let cert = pcert[]
self.remoteCert = newSeq[byte](cert.raw.len)
copyMem(addr self.remoteCert[0], cert.raw.p, cert.raw.len)
return 0
proc dtlsSend(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsSend is the procedure called by mbedtls when data needs to be sent.
# As the StunConn's write proc is asynchronous and dtlsSend cannot be async,
# we store the future of this write and await it after the end of the
# function (see write or dtlsHanshake for example).
var self = cast[DtlsConn](ctx)
var toWrite = newSeq[byte](len)
if len > 0:
copyMem(addr toWrite[0], buf, len)
trace "dtls send", len
self.sendFuture = self.conn.write(self.raddr, toWrite)
result = len.cint
proc dtlsRecv(ctx: pointer, buf: ptr byte, len: uint): cint {.cdecl.} =
# dtlsRecv is the procedure called by mbedtls when data needs to be received.
# As we cannot asynchronously await for data to be received, we use a data received
# queue. If this queue is empty, we return `MBEDTLS_ERR_SSL_WANT_READ` for us to await
# when the mbedtls proc resumed (see read or dtlsHandshake for example)
let self = cast[DtlsConn](ctx)
if self.dataRecv.len() == 0:
return MBEDTLS_ERR_SSL_WANT_READ
var dataRecv = self.dataRecv.popFirstNoWait()
copyMem(buf, addr dataRecv[0], dataRecv.len())
result = dataRecv.len().cint
trace "dtls receive", len, result
# -- Dtls Accept / Connect procedures --
proc removeConnection(self: Dtls, conn: DtlsConn, raddr: TransportAddress) {.async.} =
await conn.join()
self.connections.del(raddr)
proc accept*(self: Dtls): Future[DtlsConn] {.async.} =
var res = DtlsConn()
res.init(self.conn, self.laddr)
mb_ssl_init(res.ssl)
mb_ssl_config_init(res.config)
mb_ssl_cookie_init(res.cookie)
mb_ssl_cache_init(res.cache)
res.ctr_drbg = self.ctr_drbg
res.entropy = self.entropy
var pkey = self.serverPrivKey
var srvcert = self.serverCert
res.localCert = self.localCert
mb_ssl_config_defaults(res.config,
MBEDTLS_SSL_IS_SERVER,
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT)
mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg)
mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds
mb_ssl_conf_ca_chain(res.config, srvcert.next, nil)
mb_ssl_conf_own_cert(res.config, srvcert, pkey)
mb_ssl_cookie_setup(res.cookie, mbedtls_ctr_drbg_random, res.ctr_drbg)
mb_ssl_conf_dtls_cookies(res.config, res.cookie)
mb_ssl_set_timer_cb(res.ssl, res.timer)
mb_ssl_setup(res.ssl, res.config)
mb_ssl_session_reset(res.ssl)
mb_ssl_set_verify(res.ssl, verify, res)
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
while true:
let (raddr, buf) = await self.pendingHandshakes.popFirst()
try:
res.raddr = raddr
res.dataRecv.addLastNoWait(buf)
self.connections[raddr] = res
await res.dtlsHandshake(true)
asyncSpawn self.removeConnection(res, raddr)
break
except CatchableError as exc:
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
self.connections.del(raddr)
continue
return res
proc connect*(self: Dtls, raddr: TransportAddress): Future[DtlsConn] {.async.} =
var res = DtlsConn()
res.init(self.conn, self.laddr)
mb_ssl_init(res.ssl)
mb_ssl_config_init(res.config)
res.ctr_drbg = self.ctr_drbg
res.entropy = self.entropy
var pkey = res.ctr_drbg.generateKey()
var srvcert = res.ctr_drbg.generateCertificate(pkey)
res.localCert = newSeq[byte](srvcert.raw.len)
copyMem(addr res.localCert[0], srvcert.raw.p, srvcert.raw.len)
mb_ctr_drbg_init(res.ctr_drbg)
mb_entropy_init(res.entropy)
mb_ctr_drbg_seed(res.ctr_drbg, mbedtls_entropy_func, res.entropy, nil, 0)
mb_ssl_config_defaults(res.config,
MBEDTLS_SSL_IS_CLIENT,
MBEDTLS_SSL_TRANSPORT_DATAGRAM,
MBEDTLS_SSL_PRESET_DEFAULT)
mb_ssl_conf_rng(res.config, mbedtls_ctr_drbg_random, res.ctr_drbg)
mb_ssl_conf_read_timeout(res.config, 10000) # in milliseconds
mb_ssl_conf_ca_chain(res.config, srvcert.next, nil)
mb_ssl_set_timer_cb(res.ssl, res.timer)
mb_ssl_setup(res.ssl, res.config)
mb_ssl_set_verify(res.ssl, verify, res)
mb_ssl_conf_authmode(res.config, MBEDTLS_SSL_VERIFY_OPTIONAL)
mb_ssl_set_bio(res.ssl, cast[pointer](res), dtlsSend, dtlsRecv, nil)
res.raddr = raddr
self.connections[raddr] = res
try:
await res.dtlsHandshake(false)
asyncSpawn self.removeConnection(res, raddr)
except CatchableError as exc:
trace "Handshake fail", remoteAddress = raddr, error = exc.msg
self.connections.del(raddr)
raise exc
return res

View File

@ -1,96 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/times
import stew/byteutils
import mbedtls/pk
import mbedtls/rsa
import mbedtls/ctr_drbg
import mbedtls/x509_crt
import mbedtls/bignum
import mbedtls/md
import chronicles
# This sequence is used for debugging.
const mb_ssl_states* = @[
"MBEDTLS_SSL_HELLO_REQUEST",
"MBEDTLS_SSL_CLIENT_HELLO",
"MBEDTLS_SSL_SERVER_HELLO",
"MBEDTLS_SSL_SERVER_CERTIFICATE",
"MBEDTLS_SSL_SERVER_KEY_EXCHANGE",
"MBEDTLS_SSL_CERTIFICATE_REQUEST",
"MBEDTLS_SSL_SERVER_HELLO_DONE",
"MBEDTLS_SSL_CLIENT_CERTIFICATE",
"MBEDTLS_SSL_CLIENT_KEY_EXCHANGE",
"MBEDTLS_SSL_CERTIFICATE_VERIFY",
"MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC",
"MBEDTLS_SSL_CLIENT_FINISHED",
"MBEDTLS_SSL_SERVER_CHANGE_CIPHER_SPEC",
"MBEDTLS_SSL_SERVER_FINISHED",
"MBEDTLS_SSL_FLUSH_BUFFERS",
"MBEDTLS_SSL_HANDSHAKE_WRAPUP",
"MBEDTLS_SSL_NEW_SESSION_TICKET",
"MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT",
"MBEDTLS_SSL_HELLO_RETRY_REQUEST",
"MBEDTLS_SSL_ENCRYPTED_EXTENSIONS",
"MBEDTLS_SSL_END_OF_EARLY_DATA",
"MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY",
"MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED",
"MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO",
"MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO",
"MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO",
"MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST",
"MBEDTLS_SSL_HANDSHAKE_OVER",
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET",
"MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET_FLUSH"
]
template generateKey*(random: mbedtls_ctr_drbg_context): mbedtls_pk_context =
var res: mbedtls_pk_context
mb_pk_init(res)
discard mbedtls_pk_setup(addr res, mbedtls_pk_info_from_type(MBEDTLS_PK_RSA))
mb_rsa_gen_key(mb_pk_rsa(res), mbedtls_ctr_drbg_random, random, 2048, 65537)
let x = mb_pk_rsa(res)
res
template generateCertificate*(random: mbedtls_ctr_drbg_context,
issuer_key: mbedtls_pk_context): mbedtls_x509_crt =
let
# To be honest, I have no clue what to put here as a name
name = "C=FR,O=Status,CN=webrtc"
time_format = initTimeFormat("YYYYMMddHHmmss")
time_from = times.now().format(time_format)
time_to = (times.now() + times.years(1)).format(time_format)
var write_cert: mbedtls_x509write_cert
var serial_mpi: mbedtls_mpi
mb_x509write_crt_init(write_cert)
mb_x509write_crt_set_md_alg(write_cert, MBEDTLS_MD_SHA256);
mb_x509write_crt_set_subject_key(write_cert, issuer_key)
mb_x509write_crt_set_issuer_key(write_cert, issuer_key)
mb_x509write_crt_set_subject_name(write_cert, name)
mb_x509write_crt_set_issuer_name(write_cert, name)
mb_x509write_crt_set_validity(write_cert, time_from, time_to)
mb_x509write_crt_set_basic_constraints(write_cert, 0, -1)
mb_x509write_crt_set_subject_key_identifier(write_cert)
mb_x509write_crt_set_authority_key_identifier(write_cert)
mb_mpi_init(serial_mpi)
let serial_hex = mb_mpi_read_string(serial_mpi, 16)
mb_x509write_crt_set_serial(write_cert, serial_mpi)
let buf =
try:
mb_x509write_crt_pem(write_cert, 2048, mbedtls_ctr_drbg_random, random)
except MbedTLSError as e:
raise e
var res: mbedtls_x509_crt
mb_x509_crt_parse(res, buf)
res

View File

@ -1,406 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import tables, bitops, posix, strutils, sequtils
import chronos, chronicles, stew/[ranges/ptr_arith, byteutils, endians2]
import usrsctp
import dtls/dtls
import binary_serialization
export chronicles
logScope:
topics = "webrtc sctp"
# Implementation of an Sctp client and server using the usrsctp library.
# Usrsctp is usable as a single thread but it's not the intended way to
# use it. There's a lot of callbacks calling each other in a synchronous
# way where we want to be able to call asynchronous procedure, but cannot.
# TODO:
# - Replace doAssert by a proper exception management
# - Find a clean way to manage SCTP ports
# - Unregister address when closing
proc perror(error: cstring) {.importc, cdecl, header: "<errno.h>".}
proc printf(format: cstring) {.cdecl, importc: "printf", varargs, header: "<stdio.h>", gcsafe.}
type
SctpError* = object of CatchableError
SctpState = enum
Connecting
Connected
Closed
SctpMessageParameters* = object
protocolId*: uint32
streamId*: uint16
endOfRecord*: bool
unordered*: bool
SctpMessage* = ref object
data*: seq[byte]
info: sctp_recvv_rn
params*: SctpMessageParameters
SctpConn* = ref object
conn*: DtlsConn
state: SctpState
connectEvent: AsyncEvent
acceptEvent: AsyncEvent
readLoop: Future[void]
sctp: Sctp
udp: DatagramTransport
address: TransportAddress
sctpSocket: ptr socket
dataRecv: AsyncQueue[SctpMessage]
sentFuture: Future[void]
Sctp* = ref object
dtls: Dtls
udp: DatagramTransport
connections: Table[TransportAddress, SctpConn]
gotConnection: AsyncEvent
timersHandler: Future[void]
isServer: bool
sockServer: ptr socket
pendingConnections: seq[SctpConn]
pendingConnections2: Table[SockAddr, SctpConn]
sentAddress: TransportAddress
sentFuture: Future[void]
# These three objects are used for debugging/trace only
SctpChunk = object
chunkType: uint8
flag: uint8
length {.bin_value: it.data.len() + 4.}: uint16
data {.bin_len: it.length - 4.}: seq[byte]
SctpPacketHeader = object
srcPort: uint16
dstPort: uint16
verifTag: uint32
checksum: uint32
SctpPacketStructure = object
header: SctpPacketHeader
chunks: seq[SctpChunk]
const IPPROTO_SCTP = 132
proc getSctpPacket(buffer: seq[byte]): SctpPacketStructure =
# Only used for debugging/trace
result.header = Binary.decode(buffer, SctpPacketHeader)
var size = sizeof(SctpPacketStructure)
while size < buffer.len:
let chunk = Binary.decode(buffer[size..^1], SctpChunk)
result.chunks.add(chunk)
size.inc(chunk.length.int)
while size mod 4 != 0:
# padding; could use `size.inc(-size %% 4)` instead but it lacks clarity
size.inc(1)
# -- Asynchronous wrapper --
template usrsctpAwait(self: SctpConn|Sctp, body: untyped): untyped =
# usrsctpAwait is template which set `sentFuture` to nil then calls (usually)
# an usrsctp function. If during the synchronous run of the usrsctp function
# `sendCallback` is called, then `sentFuture` is set and waited.
self.sentFuture = nil
when type(body) is void:
body
if self.sentFuture != nil: await self.sentFuture
else:
let res = body
if self.sentFuture != nil: await self.sentFuture
res
# -- SctpConn --
proc new(T: typedesc[SctpConn], conn: DtlsConn, sctp: Sctp): T =
T(conn: conn,
sctp: sctp,
state: Connecting,
connectEvent: AsyncEvent(),
acceptEvent: AsyncEvent(),
dataRecv: newAsyncQueue[SctpMessage]() # TODO add some limit for backpressure?
)
proc read*(self: SctpConn): Future[SctpMessage] {.async.} =
# Used by DataChannel, returns SctpMessage in order to get the stream
# and protocol ids
return await self.dataRecv.popFirst()
proc toFlags(params: SctpMessageParameters): uint16 =
if params.endOfRecord:
result = result or SCTP_EOR
if params.unordered:
result = result or SCTP_UNORDERED
proc write*(self: SctpConn, buf: seq[byte],
sendParams = default(SctpMessageParameters)) {.async.} =
# Used by DataChannel, writes buf on the Dtls connection.
trace "Write", buf
self.sctp.sentAddress = self.address
var cpy = buf
let sendvErr =
if sendParams == default(SctpMessageParameters):
# If writes is called by DataChannel, sendParams should never
# be the default value. This split is useful for testing.
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
nil, 0, SCTP_SENDV_NOINFO.cuint, 0)
else:
var sendInfo = sctp_sndinfo(
snd_sid: sendParams.streamId,
# TODO: swapBytes => htonl?
snd_ppid: sendParams.protocolId.swapBytes(),
snd_flags: sendParams.toFlags)
self.usrsctpAwait:
self.sctpSocket.usrsctp_sendv(cast[pointer](addr cpy[0]), cpy.len().uint, nil, 0,
cast[pointer](addr sendInfo), sizeof(sendInfo).SockLen,
SCTP_SENDV_SNDINFO.cuint, 0)
if sendvErr < 0:
# TODO: throw an exception
perror("usrsctp_sendv")
proc write*(self: SctpConn, s: string) {.async.} =
await self.write(s.toBytes())
proc close*(self: SctpConn) {.async.} =
self.usrsctpAwait:
self.sctpSocket.usrsctp_close()
# -- usrsctp receive data callbacks --
proc handleUpcall(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when we receive data after
# connection has been established.
let
conn = cast[SctpConn](data)
events = usrsctp_get_events(sock)
trace "Handle Upcall", events
if bitand(events, SCTP_EVENT_READ) != 0:
var
message = SctpMessage(
data: newSeq[byte](4096)
)
address: Sockaddr_storage
rn: sctp_recvv_rn
addressLen = sizeof(Sockaddr_storage).SockLen
rnLen = sizeof(sctp_recvv_rn).SockLen
infotype: uint
flags: int
let n = sock.usrsctp_recvv(cast[pointer](addr message.data[0]),
message.data.len.uint,
cast[ptr SockAddr](addr address),
cast[ptr SockLen](addr addressLen),
cast[pointer](addr message.info),
cast[ptr SockLen](addr rnLen),
cast[ptr cuint](addr infotype),
cast[ptr cint](addr flags))
if n < 0:
perror("usrsctp_recvv")
return
elif n > 0:
# It might be necessary to check if infotype == SCTP_RECVV_RCVINFO
message.data.delete(n..<message.data.len())
trace "message info from handle upcall", msginfo = message.info
message.params = SctpMessageParameters(
protocolId: message.info.recvv_rcvinfo.rcv_ppid.swapBytes(),
streamId: message.info.recvv_rcvinfo.rcv_sid
)
if bitand(flags, MSG_NOTIFICATION) != 0:
trace "Notification received", length = n
else:
try:
conn.dataRecv.addLastNoWait(message)
except AsyncQueueFullError:
trace "Queue full, dropping packet"
elif bitand(events, SCTP_EVENT_WRITE) != 0:
trace "sctp event write in the upcall"
else:
warn "Handle Upcall unexpected event", events
proc handleAccept(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when accepting a connection.
trace "Handle Accept"
var
sconn: Sockaddr_conn
slen: Socklen = sizeof(Sockaddr_conn).uint32
let
sctp = cast[Sctp](data)
# TODO: check if sctpSocket != nil
sctpSocket = usrsctp_accept(sctp.sockServer, cast[ptr SockAddr](addr sconn), addr slen)
let conn = cast[SctpConn](sconn.sconn_addr)
conn.sctpSocket = sctpSocket
conn.state = Connected
var nodelay: uint32 = 1
var recvinfo: uint32 = 1
doAssert 0 == sctpSocket.usrsctp_set_non_blocking(1)
doAssert 0 == conn.sctpSocket.usrsctp_set_upcall(handleUpcall, cast[pointer](conn))
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
addr nodelay, sizeof(nodelay).SockLen)
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
addr recvinfo, sizeof(recvinfo).SockLen)
conn.acceptEvent.fire()
proc handleConnect(sock: ptr socket, data: pointer, flags: cint) {.cdecl.} =
# Callback procedure called when connecting
trace "Handle Connect"
let
conn = cast[SctpConn](data)
events = usrsctp_get_events(sock)
trace "Handle Upcall", events, state = conn.state
if conn.state == Connecting:
if bitand(events, SCTP_EVENT_ERROR) != 0:
warn "Cannot connect", address = conn.address
conn.state = Closed
elif bitand(events, SCTP_EVENT_WRITE) != 0:
conn.state = Connected
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleUpcall, data)
conn.connectEvent.fire()
else:
warn "should be connecting", currentState = conn.state
# -- usrsctp send data callback --
proc sendCallback(ctx: pointer,
buffer: pointer,
length: uint,
tos: uint8,
set_df: uint8): cint {.cdecl.} =
let data = usrsctp_dumppacket(buffer, length, SCTP_DUMP_OUTBOUND)
if data != nil:
trace "sendCallback", sctpPacket = data.getSctpPacket(), length
usrsctp_freedumpbuffer(data)
let sctpConn = cast[SctpConn](ctx)
let buf = @(buffer.makeOpenArray(byte, int(length)))
proc testSend() {.async.} =
try:
trace "Send To", address = sctpConn.address
await sctpConn.conn.write(buf)
except CatchableError as exc:
trace "Send Failed", message = exc.msg
sctpConn.sentFuture = testSend()
# -- Sctp --
proc timersHandler() {.async.} =
while true:
await sleepAsync(500.milliseconds)
usrsctp_handle_timers(500)
proc stopServer*(self: Sctp) =
if not self.isServer:
trace "Try to close a client"
return
self.isServer = false
let pcs = self.pendingConnections
self.pendingConnections = @[]
for pc in pcs:
pc.sctpSocket.usrsctp_close()
self.sockServer.usrsctp_close()
proc init*(self: Sctp, dtls: Dtls, laddr: TransportAddress) =
self.gotConnection = newAsyncEvent()
self.timersHandler = timersHandler()
self.dtls = dtls
usrsctp_init_nothreads(laddr.port.uint16, sendCallback, printf)
discard usrsctp_sysctl_set_sctp_debug_on(SCTP_DEBUG_NONE)
discard usrsctp_sysctl_set_sctp_ecn_enable(1)
usrsctp_register_address(cast[pointer](self))
proc stop*(self: Sctp) {.async.} =
# TODO: close every connections
discard self.usrsctpAwait usrsctp_finish()
self.udp.close()
proc readLoopProc(res: SctpConn) {.async.} =
while true:
let
msg = await res.conn.read()
data = usrsctp_dumppacket(unsafeAddr msg[0], uint(msg.len), SCTP_DUMP_INBOUND)
if not data.isNil():
trace "Receive data", remoteAddress = res.conn.raddr,
sctpPacket = data.getSctpPacket()
usrsctp_freedumpbuffer(data)
usrsctp_conninput(cast[pointer](res), unsafeAddr msg[0], uint(msg.len), 0)
proc accept*(self: Sctp): Future[SctpConn] {.async.} =
if not self.isServer:
raise newException(SctpError, "Not a server")
var res = SctpConn.new(await self.dtls.accept(), self)
usrsctp_register_address(cast[pointer](res))
res.readLoop = res.readLoopProc()
res.acceptEvent.clear()
await res.acceptEvent.wait()
return res
proc listen*(self: Sctp, sctpPort: uint16 = 5000) =
if self.isServer:
trace "Try to start the server twice"
return
self.isServer = true
trace "Listening", sctpPort
doAssert 0 == usrsctp_sysctl_set_sctp_blackhole(2)
doAssert 0 == usrsctp_sysctl_set_sctp_no_csum_on_loopback(0)
doAssert 0 == usrsctp_sysctl_set_sctp_delayed_sack_time_default(0)
let sock = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
var on: int = 1
doAssert 0 == usrsctp_set_non_blocking(sock, 1)
var sin: Sockaddr_in
sin.sin_family = posix.AF_INET.uint16
sin.sin_port = htons(sctpPort)
sin.sin_addr.s_addr = htonl(INADDR_ANY)
doAssert 0 == usrsctp_bind(sock, cast[ptr SockAddr](addr sin), SockLen(sizeof(Sockaddr_in)))
doAssert 0 >= usrsctp_listen(sock, 1)
doAssert 0 == sock.usrsctp_set_upcall(handleAccept, cast[pointer](self))
self.sockServer = sock
proc connect*(self: Sctp,
address: TransportAddress,
sctpPort: uint16 = 5000): Future[SctpConn] {.async.} =
let
sctpSocket = usrsctp_socket(AF_CONN, posix.SOCK_STREAM, IPPROTO_SCTP, nil, nil, 0, nil)
conn = SctpConn.new(await self.dtls.connect(address), self)
trace "Create Connection", address
conn.sctpSocket = sctpSocket
conn.state = Connected
var nodelay: uint32 = 1
var recvinfo: uint32 = 1
doAssert 0 == usrsctp_set_non_blocking(conn.sctpSocket, 1)
doAssert 0 == usrsctp_set_upcall(conn.sctpSocket, handleConnect, cast[pointer](conn))
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_NODELAY,
addr nodelay, sizeof(nodelay).SockLen)
doAssert 0 == conn.sctpSocket.usrsctp_setsockopt(IPPROTO_SCTP, SCTP_RECVRCVINFO,
addr recvinfo, sizeof(recvinfo).SockLen)
var sconn: Sockaddr_conn
sconn.sconn_family = AF_CONN
sconn.sconn_port = htons(sctpPort)
sconn.sconn_addr = cast[pointer](conn)
self.sentAddress = address
usrsctp_register_address(cast[pointer](conn))
conn.readLoop = conn.readLoopProc()
let connErr = self.usrsctpAwait:
conn.sctpSocket.usrsctp_connect(cast[ptr SockAddr](addr sconn), SockLen(sizeof(sconn)))
doAssert 0 == connErr or errno == posix.EINPROGRESS, ($errno)
conn.state = Connecting
conn.connectEvent.clear()
await conn.connectEvent.wait()
# TODO: check connection state, if closed throw an exception
self.connections[address] = conn
return conn

View File

@ -1,151 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import bitops, strutils
import chronos,
chronicles,
binary_serialization,
stew/objects,
stew/byteutils
import stun_attributes
export binary_serialization
logScope:
topics = "webrtc stun"
const
msgHeaderSize = 20
magicCookieSeq = @[ 0x21'u8, 0x12, 0xa4, 0x42 ]
magicCookie = 0x2112a442
BindingRequest = 0x0001'u16
BindingResponse = 0x0101'u16
proc decode(T: typedesc[RawStunAttribute], cnt: seq[byte]): seq[RawStunAttribute] =
const pad = @[0, 3, 2, 1]
var padding = 0
while padding < cnt.len():
let attr = Binary.decode(cnt[padding ..^ 1], RawStunAttribute)
result.add(attr)
padding += 4 + attr.value.len()
padding += pad[padding mod 4]
type
# Stun Header
# 0 1 2 3
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# |0 0| STUN Message Type | Message Length |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# | Magic Cookie |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# | |
# | Transaction ID (96 bits) |
# | |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# Message type:
# 0x0001: Binding Request
# 0x0101: Binding Response
# 0x0111: Binding Error Response
# 0x0002: Shared Secret Request
# 0x0102: Shared Secret Response
# 0x0112: Shared Secret Error Response
RawStunMessage = object
msgType: uint16
length* {.bin_value: it.content.len().}: uint16
magicCookie: uint32
transactionId: array[12, byte] # Down from 16 to 12 bytes in RFC5389
content* {.bin_len: it.length.}: seq[byte]
StunMessage* = object
msgType*: uint16
transactionId*: array[12, byte]
attributes*: seq[RawStunAttribute]
Stun* = object
proc getAttribute(attrs: seq[RawStunAttribute], typ: uint16): Option[seq[byte]] =
for attr in attrs:
if attr.attributeType == typ:
return some(attr.value)
return none(seq[byte])
proc isMessage*(T: typedesc[Stun], msg: seq[byte]): bool =
msg.len >= msgHeaderSize and msg[4..<8] == magicCookieSeq and bitand(0xC0'u8, msg[0]) == 0'u8
proc addLength(msgEncoded: var seq[byte], length: uint16) =
let
hi = (length div 256'u16).uint8
lo = (length mod 256'u16).uint8
msgEncoded[2] = msgEncoded[2] + hi
if msgEncoded[3].int + lo.int >= 256:
msgEncoded[2] = msgEncoded[2] + 1
msgEncoded[3] = ((msgEncoded[3].int + lo.int) mod 256).uint8
else:
msgEncoded[3] = msgEncoded[3] + lo
proc decode*(T: typedesc[StunMessage], msg: seq[byte]): StunMessage =
let smi = Binary.decode(msg, RawStunMessage)
return T(msgType: smi.msgType,
transactionId: smi.transactionId,
attributes: RawStunAttribute.decode(smi.content))
proc encode*(msg: StunMessage, userOpt: Option[seq[byte]] = none(seq[byte])): seq[byte] =
const pad = @[0, 3, 2, 1]
var smi = RawStunMessage(msgType: msg.msgType,
magicCookie: magicCookie,
transactionId: msg.transactionId)
for attr in msg.attributes:
smi.content.add(Binary.encode(attr))
smi.content.add(newSeq[byte](pad[smi.content.len() mod 4]))
result = Binary.encode(smi)
if userOpt.isSome():
let username = string.fromBytes(userOpt.get())
let usersplit = username.split(":")
if usersplit.len() == 2 and usersplit[0].startsWith("libp2p+webrtc+v1/"):
result.addLength(24)
result.add(Binary.encode(MessageIntegrity.encode(result, toBytes(usersplit[0]))))
result.addLength(8)
result.add(Binary.encode(Fingerprint.encode(result)))
proc getResponse*(T: typedesc[Stun], msg: seq[byte],
ta: TransportAddress): Option[seq[byte]] =
if ta.family != AddressFamily.IPv4 and ta.family != AddressFamily.IPv6:
return none(seq[byte])
let sm =
try:
StunMessage.decode(msg)
except CatchableError as exc:
return none(seq[byte])
if sm.msgType != BindingRequest:
return none(seq[byte])
var res = StunMessage(msgType: BindingResponse,
transactionId: sm.transactionId)
var unknownAttr: seq[uint16]
for attr in sm.attributes:
let typ = attr.attributeType
if typ.isRequired() and typ notin StunAttributeEnum:
unknownAttr.add(typ)
if unknownAttr.len() > 0:
res.attributes.add(ErrorCode.encode(ECUnknownAttribute))
res.attributes.add(UnknownAttribute.encode(unknownAttr))
return some(res.encode(sm.attributes.getAttribute(AttrUsername.uint16)))
res.attributes.add(XorMappedAddress.encode(ta, sm.transactionId))
return some(res.encode(sm.attributes.getAttribute(AttrUsername.uint16)))
proc new*(T: typedesc[Stun]): T =
result = T()

View File

@ -1,228 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import std/sha1, sequtils, typetraits, std/md5
import binary_serialization,
stew/byteutils,
chronos
# -- Utils --
proc createCrc32Table(): array[0..255, uint32] =
for i in 0..255:
var rem = i.uint32
for j in 0..7:
if (rem and 1) > 0:
rem = (rem shr 1) xor 0xedb88320'u32
else:
rem = rem shr 1
result[i] = rem
proc crc32(s: seq[byte]): uint32 =
# CRC-32 is used for the fingerprint attribute
# See https://datatracker.ietf.org/doc/html/rfc5389#section-15.5
const crc32table = createCrc32Table()
result = 0xffffffff'u32
for c in s:
result = (result shr 8) xor crc32table[(result and 0xff) xor c]
result = not result
proc hmacSha1(key: seq[byte], msg: seq[byte]): seq[byte] =
# HMAC-SHA1 is used for the message integrity attribute
# See https://datatracker.ietf.org/doc/html/rfc5389#section-15.4
let
keyPadded =
if len(key) > 64:
@(secureHash(key.mapIt(it.chr)).distinctBase)
elif key.len() < 64:
key.concat(newSeq[byte](64 - key.len()))
else:
key
innerHash = keyPadded.
mapIt(it xor 0x36'u8).
concat(msg).
mapIt(it.chr).
secureHash()
outerHash = keyPadded.
mapIt(it xor 0x5c'u8).
concat(@(innerHash.distinctBase)).
mapIt(it.chr).
secureHash()
return @(outerHash.distinctBase)
# -- Attributes --
# There are obviously some attributes implementation that are missing,
# it might be something to do eventually if we want to make this
# repository work for other project than nim-libp2p
#
# Stun Attribute
# 0 1 2 3
# 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# | Type | Length |
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
# | Value (variable) ....
# +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
type
StunAttributeEncodingError* = object of CatchableError
RawStunAttribute* = object
attributeType*: uint16
length* {.bin_value: it.value.len.}: uint16
value* {.bin_len: it.length.}: seq[byte]
StunAttributeEnum* = enum
AttrMappedAddress = 0x0001
AttrChangeRequest = 0x0003 # RFC5780 Nat Behavior Discovery
AttrSourceAddress = 0x0004 # Deprecated
AttrChangedAddress = 0x0005 # Deprecated
AttrUsername = 0x0006
AttrMessageIntegrity = 0x0008
AttrErrorCode = 0x0009
AttrUnknownAttributes = 0x000A
AttrChannelNumber = 0x000C # RFC5766 TURN
AttrLifetime = 0x000D # RFC5766 TURN
AttrXORPeerAddress = 0x0012 # RFC5766 TURN
AttrData = 0x0013 # RFC5766 TURN
AttrRealm = 0x0014
AttrNonce = 0x0015
AttrXORRelayedAddress = 0x0016 # RFC5766 TURN
AttrRequestedAddressFamily = 0x0017 # RFC6156
AttrEvenPort = 0x0018 # RFC5766 TURN
AttrRequestedTransport = 0x0019 # RFC5766 TURN
AttrDontFragment = 0x001A # RFC5766 TURN
AttrMessageIntegritySHA256 = 0x001C # RFC8489 STUN (v2)
AttrPasswordAlgorithm = 0x001D # RFC8489 STUN (v2)
AttrUserhash = 0x001E # RFC8489 STUN (v2)
AttrXORMappedAddress = 0x0020
AttrReservationToken = 0x0022 # RFC5766 TURN
AttrPriority = 0x0024 # RFC5245 ICE
AttrUseCandidate = 0x0025 # RFC5245 ICE
AttrPadding = 0x0026 # RFC5780 Nat Behavior Discovery
AttrResponsePort = 0x0027 # RFC5780 Nat Behavior Discovery
AttrConnectionID = 0x002a # RFC6062 TURN Extensions
AttrPasswordAlgorithms = 0x8002 # RFC8489 STUN (v2)
AttrAlternateDomain = 0x8003 # RFC8489 STUN (v2)
AttrSoftware = 0x8022
AttrAlternateServer = 0x8023
AttrCacheTimeout = 0x8027 # RFC5780 Nat Behavior Discovery
AttrFingerprint = 0x8028
AttrICEControlled = 0x8029 # RFC5245 ICE
AttrICEControlling = 0x802A # RFC5245 ICE
AttrResponseOrigin = 0x802b # RFC5780 Nat Behavior Discovery
AttrOtherAddress = 0x802C # RFC5780 Nat Behavior Discovery
AttrOrigin = 0x802F
proc isRequired*(typ: uint16): bool = typ <= 0x7FFF'u16
proc isOptional*(typ: uint16): bool = typ >= 0x8000'u16
# Error Code
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.6
type
ErrorCodeEnum* = enum
ECTryAlternate = 300
ECBadRequest = 400
ECUnauthenticated = 401
ECUnknownAttribute = 420
ECStaleNonce = 438
ECServerError = 500
ErrorCode* = object
reserved1: uint16 # should be 0
reserved2 {.bin_bitsize: 5.}: uint8 # should be 0
class {.bin_bitsize: 3.}: uint8
number: uint8
reason: seq[byte]
proc encode*(T: typedesc[ErrorCode], code: ErrorCodeEnum, reason: string = ""): RawStunAttribute =
let
ec = T(class: (code.uint16 div 100'u16).uint8,
number: (code.uint16 mod 100'u16).uint8,
reason: reason.toBytes())
value = Binary.encode(ec)
result = RawStunAttribute(attributeType: AttrErrorCode.uint16,
length: value.len().uint16,
value: value)
# Unknown Attribute
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.9
type
UnknownAttribute* = object
unknownAttr: seq[uint16]
proc encode*(T: typedesc[UnknownAttribute], unknownAttr: seq[uint16]): RawStunAttribute =
let
ua = T(unknownAttr: unknownAttr)
value = Binary.encode(ua)
result = RawStunAttribute(attributeType: AttrUnknownAttributes.uint16,
length: value.len().uint16,
value: value)
# Fingerprint
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.5
type
Fingerprint* = object
crc32: uint32
proc encode*(T: typedesc[Fingerprint], msg: seq[byte]): RawStunAttribute =
let value = Binary.encode(T(crc32: crc32(msg) xor 0x5354554e'u32))
result = RawStunAttribute(attributeType: AttrFingerprint.uint16,
length: value.len().uint16,
value: value)
# Xor Mapped Address
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.2
type
MappedAddressFamily {.size: 1.} = enum
MAFIPv4 = 0x01
MAFIPv6 = 0x02
XorMappedAddress* = object
reserved: uint8 # should be 0
family: MappedAddressFamily
port: uint16
address: seq[byte]
proc encode*(T: typedesc[XorMappedAddress], ta: TransportAddress,
tid: array[12, byte]): RawStunAttribute =
const magicCookie = @[ 0x21'u8, 0x12, 0xa4, 0x42 ]
let
(address, family) =
if ta.family == AddressFamily.IPv4:
var s = newSeq[uint8](4)
for i in 0..3:
s[i] = ta.address_v4[i] xor magicCookie[i]
(s, MAFIPv4)
else:
let magicCookieTid = magicCookie.concat(@tid)
var s = newSeq[uint8](16)
for i in 0..15:
s[i] = ta.address_v6[i] xor magicCookieTid[i]
(s, MAFIPv6)
xma = T(family: family, port: ta.port.distinctBase xor 0x2112'u16, address: address)
value = Binary.encode(xma)
result = RawStunAttribute(attributeType: AttrXORMappedAddress.uint16,
length: value.len().uint16,
value: value)
# Message Integrity
# https://datatracker.ietf.org/doc/html/rfc5389#section-15.4
type
MessageIntegrity* = object
msgInt: seq[byte]
proc encode*(T: typedesc[MessageIntegrity], msg: seq[byte], key: seq[byte]): RawStunAttribute =
let value = Binary.encode(T(msgInt: hmacSha1(key, msg)))
result = RawStunAttribute(attributeType: AttrMessageIntegrity.uint16,
length: value.len().uint16, value: value)

View File

@ -1,61 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import chronos, chronicles
import ../udp_connection, stun
logScope:
topics = "webrtc stun"
# TODO: Work fine when behaves like a server, need to implement the client side
type
StunConn* = ref object
conn: UdpConn
laddr: TransportAddress
dataRecv: AsyncQueue[(seq[byte], TransportAddress)]
handlesFut: Future[void]
closed: bool
proc handles(self: StunConn) {.async.} =
while true:
let (msg, raddr) = await self.conn.read()
if Stun.isMessage(msg):
let res = Stun.getResponse(msg, self.laddr)
if res.isSome():
await self.conn.write(raddr, res.get())
else:
self.dataRecv.addLastNoWait((msg, raddr))
proc init*(self: StunConn, conn: UdpConn, laddr: TransportAddress) =
self.conn = conn
self.laddr = laddr
self.closed = false
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
self.handlesFut = self.handles()
proc close*(self: StunConn) {.async.} =
if self.closed:
debug "Try to close StunConn twice"
return
self.handlesFut.cancel() # check before?
await self.conn.close()
proc write*(self: StunConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
if self.closed:
debug "Try to write on an already closed StunConn"
return
await self.conn.write(raddr, msg)
proc read*(self: StunConn): Future[(seq[byte], TransportAddress)] {.async.} =
if self.closed:
debug "Try to read on an already closed StunConn"
return
return await self.dataRecv.popFirst()

View File

@ -1,57 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import chronos, chronicles
logScope:
topics = "webrtc udp"
# UdpConn is a small wrapper of the chronos DatagramTransport.
# It's the simplest solution we found to store the message and
# the remote address used by the underlying protocols (dtls/sctp etc...)
type
UdpConn* = ref object
laddr*: TransportAddress
udp: DatagramTransport
dataRecv: AsyncQueue[(seq[byte], TransportAddress)]
closed: bool
proc init*(self: UdpConn, laddr: TransportAddress) =
self.laddr = laddr
self.closed = false
proc onReceive(udp: DatagramTransport, address: TransportAddress) {.async, gcsafe.} =
trace "UDP onReceive"
let msg = udp.getMessage()
self.dataRecv.addLastNoWait((msg, address))
self.dataRecv = newAsyncQueue[(seq[byte], TransportAddress)]()
self.udp = newDatagramTransport(onReceive, local = laddr)
proc close*(self: UdpConn) {.async.} =
if self.closed:
debug "Try to close UdpConn twice"
return
self.closed = true
self.udp.close()
proc write*(self: UdpConn, raddr: TransportAddress, msg: seq[byte]) {.async.} =
if self.closed:
debug "Try to write on an already closed UdpConn"
return
trace "UDP write", msg
await self.udp.sendTo(raddr, msg)
proc read*(self: UdpConn): Future[(seq[byte], TransportAddress)] {.async.} =
if self.closed:
debug "Try to read on an already closed UdpConn"
return
trace "UDP read"
return await self.dataRecv.popFirst()

View File

@ -1,44 +0,0 @@
# Nim-WebRTC
# Copyright (c) 2024 Status Research & Development GmbH
# Licensed under either of
# * Apache License, version 2.0, ([LICENSE-APACHE](LICENSE-APACHE))
# * MIT license ([LICENSE-MIT](LICENSE-MIT))
# at your option.
# This file may not be copied, modified, or distributed except according to
# those terms.
import chronos, chronicles
import udp_connection
import stun/stun_connection
import dtls/dtls
import sctp, datachannel
logScope:
topics = "webrtc"
type
WebRTC* = ref object
udp*: UdpConn
stun*: StunConn
dtls*: Dtls
sctp*: Sctp
port: int
proc new*(T: typedesc[WebRTC], address: TransportAddress): T =
result = T(udp: UdpConn(), stun: StunConn(), dtls: Dtls(), sctp: Sctp())
result.udp.init(address)
result.stun.init(result.udp, address)
result.dtls.init(result.stun, address)
result.sctp.init(result.dtls, address)
proc listen*(self: WebRTC) =
self.sctp.listen()
proc connect*(self: WebRTC, raddr: TransportAddress): Future[DataChannelConnection] {.async.} =
let sctpConn = await self.sctp.connect(raddr) # TODO: Port?
result = DataChannelConnection.new(sctpConn)
proc accept*(w: WebRTC): Future[DataChannelConnection] {.async.} =
let sctpConn = await w.sctp.accept()
result = DataChannelConnection.new(sctpConn)