Actually fix incremental sending

This commit is contained in:
Joshua Boniface 2024-09-30 16:30:23 -04:00
parent 5f7aa0b2d6
commit fb8561cc5d
3 changed files with 103 additions and 77 deletions

View File

@ -3771,7 +3771,7 @@ class API_VM_Snapshot_Receive_Block(Resource):
reqargs.get("volume"), reqargs.get("volume"),
reqargs.get("snapshot"), reqargs.get("snapshot"),
int(reqargs.get("size")), int(reqargs.get("size")),
flask.request.stream, flask.request,
) )
@RequestParser( @RequestParser(
@ -3846,7 +3846,7 @@ class API_VM_Snapshot_Receive_Block(Resource):
reqargs.get("volume"), reqargs.get("volume"),
reqargs.get("snapshot"), reqargs.get("snapshot"),
reqargs.get("source_snapshot"), reqargs.get("source_snapshot"),
flask.request.stream, flask.request,
) )
@RequestParser( @RequestParser(

View File

@ -1306,7 +1306,7 @@ def vm_flush_locks(zkhandler, vm):
@ZKConnection(config) @ZKConnection(config)
def vm_snapshot_receive_block_full(zkhandler, pool, volume, snapshot, size, stream): def vm_snapshot_receive_block_full(zkhandler, pool, volume, snapshot, size, request):
""" """
Receive an RBD volume from a remote system Receive an RBD volume from a remote system
""" """
@ -1350,7 +1350,7 @@ def vm_snapshot_receive_block_full(zkhandler, pool, volume, snapshot, size, stre
logger.info(f"Importing full snapshot {pool}/{volume}@{snapshot}") logger.info(f"Importing full snapshot {pool}/{volume}@{snapshot}")
while True: while True:
chunk = flask.request.stream.read(chunk_size) chunk = request.stream.read(chunk_size)
if not chunk: if not chunk:
break break
image.write(chunk, last_chunk) image.write(chunk, last_chunk)
@ -1360,10 +1360,12 @@ def vm_snapshot_receive_block_full(zkhandler, pool, volume, snapshot, size, stre
ioctx.close() ioctx.close()
cluster.shutdown() cluster.shutdown()
return {"message": "Successfully received RBD block device"}, 200
@ZKConnection(config) @ZKConnection(config)
def vm_snapshot_receive_block_diff( def vm_snapshot_receive_block_diff(
zkhandler, pool, volume, snapshot, source_snapshot, stream zkhandler, pool, volume, snapshot, source_snapshot, request
): ):
""" """
Receive an RBD volume from a remote system Receive an RBD volume from a remote system
@ -1376,22 +1378,29 @@ def vm_snapshot_receive_block_diff(
ioctx = cluster.open_ioctx(pool) ioctx = cluster.open_ioctx(pool)
image = rbd.Image(ioctx, volume) image = rbd.Image(ioctx, volume)
logger.info( if len(request.files) > 0:
f"Applying diff between {pool}/{volume}@{source_snapshot} and {snapshot}" logger.info(f"Applying {len(request.files)} RBD diff chunks for {snapshot}")
)
chunk = stream.read() for i in range(len(request.files)):
object_key = f"object_{i}"
# Extract the offset and length (8 bytes each) and the data if object_key in request.files:
offset = int.from_bytes(chunk[:8], "big") object_data = request.files[object_key].read()
length = int.from_bytes(chunk[8:16], "big") offset = int.from_bytes(object_data[:8], "big")
data = chunk[16 : 16 + length] length = int.from_bytes(object_data[8:16], "big")
image.write(data, offset) data = object_data[16 : 16 + length]
logger.info(f"Applying RBD diff chunk at {offset} ({length} bytes)")
image.write(data, offset)
else:
return {"message": "No data received"}, 400
image.close() image.close()
ioctx.close() ioctx.close()
cluster.shutdown() cluster.shutdown()
return {
"message": f"Successfully received {len(request.files)} RBD diff chunks"
}, 200
@ZKConnection(config) @ZKConnection(config)
def vm_snapshot_receive_block_createsnap(zkhandler, pool, volume, snapshot): def vm_snapshot_receive_block_createsnap(zkhandler, pool, volume, snapshot):
@ -1423,6 +1432,8 @@ def vm_snapshot_receive_block_createsnap(zkhandler, pool, volume, snapshot):
output = {"message": retdata.replace('"', "'")} output = {"message": retdata.replace('"', "'")}
return output, retcode return output, retcode
return {"message": "Successfully received VM configuration data"}, 200
@ZKConnection(config) @ZKConnection(config)
def vm_snapshot_receive_config(zkhandler, snapshot, vm_config, source_snapshot=None): def vm_snapshot_receive_config(zkhandler, snapshot, vm_config, source_snapshot=None):

View File

@ -3190,18 +3190,20 @@ def vm_worker_send_snapshot(
destination_api_timeout = (3.05, 172800) destination_api_timeout = (3.05, 172800)
destination_api_headers = { destination_api_headers = {
"X-Api-Key": destination_api_key, "X-Api-Key": destination_api_key,
"Content-Type": "application/octet-stream",
} }
session = requests.Session()
session.headers.update(destination_api_headers)
session.verify = destination_api_verify_ssl
session.timeout = destination_api_timeout
try: try:
# Hit the API root; this should return "PVC API version x" # Hit the API root; this should return "PVC API version x"
response = requests.get( response = session.get(
f"{destination_api_uri}/", f"{destination_api_uri}/",
timeout=destination_api_timeout, timeout=destination_api_timeout,
headers=None,
params=None, params=None,
data=None, data=None,
verify=destination_api_verify_ssl,
) )
if "PVC API" not in response.json().get("message"): if "PVC API" not in response.json().get("message"):
raise ValueError("Remote API is not a PVC API or incorrect URI given") raise ValueError("Remote API is not a PVC API or incorrect URI given")
@ -3225,13 +3227,10 @@ def vm_worker_send_snapshot(
return False return False
# Hit the API "/status" endpoint to validate API key and cluster status # Hit the API "/status" endpoint to validate API key and cluster status
response = requests.get( response = session.get(
f"{destination_api_uri}/status", f"{destination_api_uri}/status",
timeout=destination_api_timeout,
headers=destination_api_headers,
params=None, params=None,
data=None, data=None,
verify=destination_api_verify_ssl,
) )
destination_cluster_status = response.json() destination_cluster_status = response.json()
current_destination_pvc_version = destination_cluster_status.get( current_destination_pvc_version = destination_cluster_status.get(
@ -3260,13 +3259,10 @@ def vm_worker_send_snapshot(
return False return False
# Check if the VM already exists on the remote # Check if the VM already exists on the remote
response = requests.get( response = session.get(
f"{destination_api_uri}/vm/{domain}", f"{destination_api_uri}/vm/{domain}",
timeout=destination_api_timeout,
headers=destination_api_headers,
params=None, params=None,
data=None, data=None,
verify=destination_api_verify_ssl,
) )
destination_vm_status = response.json() destination_vm_status = response.json()
if type(destination_vm_status) is list and len(destination_vm_status) > 0: if type(destination_vm_status) is list and len(destination_vm_status) > 0:
@ -3374,18 +3370,12 @@ def vm_worker_send_snapshot(
"snapshot": snapshot_name, "snapshot": snapshot_name,
"source_snapshot": incremental_parent, "source_snapshot": incremental_parent,
} }
send_headers = {
"X-Api-Key": destination_api_key,
"Content-Type": "application/json",
}
try: try:
response = requests.post( response = session.post(
f"{destination_api_uri}/vm/{vm_name}/snapshot/receive/config", f"{destination_api_uri}/vm/{vm_name}/snapshot/receive/config",
timeout=destination_api_timeout, headers={"Content-Type": "application/json"},
headers=send_headers,
params=send_params, params=send_params,
json=vm_detail, json=vm_detail,
verify=destination_api_verify_ssl,
) )
response.raise_for_status() response.raise_for_status()
except Exception as e: except Exception as e:
@ -3441,13 +3431,10 @@ def vm_worker_send_snapshot(
) )
# Check if the volume exists on the target # Check if the volume exists on the target
response = requests.get( response = session.get(
f"{destination_api_uri}/storage/ceph/volume/{pool}/{volume}", f"{destination_api_uri}/storage/ceph/volume/{pool}/{volume}",
timeout=destination_api_timeout,
headers=destination_api_headers,
params=None, params=None,
data=None, data=None,
verify=destination_api_verify_ssl,
) )
if response.status_code != 404: if response.status_code != 404:
fail( fail(
@ -3467,7 +3454,7 @@ def vm_worker_send_snapshot(
if incremental_parent is not None: if incremental_parent is not None:
# Diff between incremental_parent and snapshot # Diff between incremental_parent and snapshot
celery_message = ( celery_message = (
f"Sending diff {incremental_parent} -> {snapshot_name} for {rbd_name}" f"Sending diff of {rbd_name}@{incremental_parent} {snapshot_name}"
) )
else: else:
# Full image transfer # Full image transfer
@ -3481,17 +3468,8 @@ def vm_worker_send_snapshot(
total=total_stages, total=total_stages,
) )
send_headers = {
"X-Api-Key": destination_api_key,
"Content-Type": "application/octet-stream",
"Transfer-Encoding": None, # Disable chunked transfer encoding
}
if incremental_parent is not None: if incremental_parent is not None:
# Create a single session to reuse connections # Create a single session to reuse connections
session = requests.Session()
executor = ThreadPoolExecutor(max_workers=8)
send_params = { send_params = {
"pool": pool, "pool": pool,
"volume": volume, "volume": volume,
@ -3499,6 +3477,13 @@ def vm_worker_send_snapshot(
"source_snapshot": incremental_parent, "source_snapshot": incremental_parent,
} }
session.params.update(send_params)
# Send 32 objects (128MB) at once
send_max_objects = 32
batch_size_mb = 4 * send_max_objects
batch_size = batch_size_mb * 1024 * 1024
total_chunks = 0 total_chunks = 0
def diff_cb_count(offset, length, exists): def diff_cb_count(offset, length, exists):
@ -3507,33 +3492,63 @@ def vm_worker_send_snapshot(
total_chunks += 1 total_chunks += 1
current_chunk = 0 current_chunk = 0
buffer = list()
buffer_size = 0
last_chunk_time = time.time()
def send_block(block): def send_batch_multipart(buffer):
response = session.put( nonlocal last_chunk_time
f"{destination_api_uri}/vm/{vm_name}/snapshot/receive/block", files = {}
timeout=destination_api_timeout, for i in range(len(buffer)):
headers=send_headers, files[f"object_{i}"] = (
params=send_params, f"object_{i}",
data=block, buffer[i],
verify=destination_api_verify_ssl, "application/octet-stream",
)
try:
response = session.put(
f"{destination_api_uri}/vm/{vm_name}/snapshot/receive/block",
files=files,
stream=True,
)
response.raise_for_status()
except Exception as e:
fail(
celery,
f"Failed to send diff batch ({e}): {response.json()['message']}",
)
return False
current_chunk_time = time.time()
chunk_time = current_chunk_time - last_chunk_time
last_chunk_time = current_chunk_time
chunk_speed = round(batch_size_mb / chunk_time, 1)
update(
celery,
celery_message + f" ({chunk_speed} MB/s)",
current=current_stage,
total=total_stages,
) )
response.raise_for_status()
def add_block_to_multipart(buffer, offset, length, data):
part_data = (
offset.to_bytes(8, "big") + length.to_bytes(8, "big") + data
) # Add header and data
buffer.append(part_data)
def diff_cb_send(offset, length, exists): def diff_cb_send(offset, length, exists):
nonlocal current_chunk nonlocal current_chunk, buffer, buffer_size
if exists: if exists:
# Read the data for the current block
data = image.read(offset, length) data = image.read(offset, length)
block = offset.to_bytes(8, "big") + length.to_bytes(8, "big") + data # Add the block to the multipart buffer
add_block_to_multipart(buffer, offset, length, data)
executor.submit(send_block, block)
current_chunk += 1 current_chunk += 1
buffer_size += len(data)
update( if buffer_size >= batch_size:
celery, send_batch_multipart(buffer)
celery_message + f" ({current_chunk}/{total_chunks} objects)", buffer.clear() # Clear the buffer after sending
current=current_stage, buffer_size = 0 # Reset buffer size
total=total_stages,
)
try: try:
image.set_snap(snapshot_name) image.set_snap(snapshot_name)
@ -3543,6 +3558,11 @@ def vm_worker_send_snapshot(
image.diff_iterate( image.diff_iterate(
0, size, incremental_parent, diff_cb_send, whole_object=True 0, size, incremental_parent, diff_cb_send, whole_object=True
) )
if buffer:
send_batch_multipart(buffer)
buffer.clear() # Clear the buffer after sending
buffer_size = 0 # Reset buffer size
except Exception: except Exception:
fail( fail(
celery, celery,
@ -3583,13 +3603,11 @@ def vm_worker_send_snapshot(
} }
try: try:
response = requests.post( response = session.post(
f"{destination_api_uri}/vm/{vm_name}/snapshot/receive/block", f"{destination_api_uri}/vm/{vm_name}/snapshot/receive/block",
timeout=destination_api_timeout, headers={"Content-Type": "application/octet-stream"},
headers=send_headers,
params=send_params, params=send_params,
data=full_chunker(), data=full_chunker(),
verify=destination_api_verify_ssl,
) )
response.raise_for_status() response.raise_for_status()
except Exception: except Exception:
@ -3609,12 +3627,9 @@ def vm_worker_send_snapshot(
"snapshot": snapshot_name, "snapshot": snapshot_name,
} }
try: try:
response = requests.patch( response = session.patch(
f"{destination_api_uri}/vm/{vm_name}/snapshot/receive/block", f"{destination_api_uri}/vm/{vm_name}/snapshot/receive/block",
timeout=destination_api_timeout,
headers=send_headers,
params=send_params, params=send_params,
verify=destination_api_verify_ssl,
) )
response.raise_for_status() response.raise_for_status()
except Exception: except Exception: