mirror of
https://github.com/yt-dlp/yt-dlp
synced 2024-12-27 21:59:17 +01:00
[utils] Make JSON file writes atomic (Fixes #3549)
This commit is contained in:
parent
3b95347bb6
commit
181c8655c7
1 changed files with 30 additions and 11 deletions
|
@ -24,6 +24,7 @@ import socket
|
||||||
import struct
|
import struct
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
import xml.etree.ElementTree
|
import xml.etree.ElementTree
|
||||||
import zlib
|
import zlib
|
||||||
|
@ -228,18 +229,36 @@ else:
|
||||||
assert type(s) == type(u'')
|
assert type(s) == type(u'')
|
||||||
print(s)
|
print(s)
|
||||||
|
|
||||||
# In Python 2.x, json.dump expects a bytestream.
|
|
||||||
# In Python 3.x, it writes to a character stream
|
|
||||||
if sys.version_info < (3,0):
|
|
||||||
def write_json_file(obj, fn):
|
|
||||||
with open(fn, 'wb') as f:
|
|
||||||
json.dump(obj, f)
|
|
||||||
else:
|
|
||||||
def write_json_file(obj, fn):
|
|
||||||
with open(fn, 'w', encoding='utf-8') as f:
|
|
||||||
json.dump(obj, f)
|
|
||||||
|
|
||||||
if sys.version_info >= (2,7):
|
def write_json_file(obj, fn):
|
||||||
|
""" Encode obj as JSON and write it to fn, atomically """
|
||||||
|
|
||||||
|
# In Python 2.x, json.dump expects a bytestream.
|
||||||
|
# In Python 3.x, it writes to a character stream
|
||||||
|
if sys.version_info < (3, 0):
|
||||||
|
mode = 'wb'
|
||||||
|
encoding = None
|
||||||
|
else:
|
||||||
|
mode = 'w'
|
||||||
|
encoding = 'utf-8'
|
||||||
|
tf = tempfile.NamedTemporaryFile(
|
||||||
|
suffix='.tmp', prefix=os.path.basename(fn) + '.',
|
||||||
|
dir=os.path.dirname(fn),
|
||||||
|
delete=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
with tf:
|
||||||
|
json.dump(obj, tf)
|
||||||
|
os.rename(tf.name, fn)
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
os.remove(tf.name)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
if sys.version_info >= (2, 7):
|
||||||
def find_xpath_attr(node, xpath, key, val):
|
def find_xpath_attr(node, xpath, key, val):
|
||||||
""" Find the xpath xpath[@key=val] """
|
""" Find the xpath xpath[@key=val] """
|
||||||
assert re.match(r'^[a-zA-Z-]+$', key)
|
assert re.match(r'^[a-zA-Z-]+$', key)
|
||||||
|
|
Loading…
Reference in a new issue